1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 06:06:07 +00:00

Merge branch 'rocm-main' into ci-upstream-sync-98_1

This commit is contained in:
Charles Hofer 2025-01-28 21:18:47 +00:00
commit 47580efda5
15 changed files with 527 additions and 203 deletions

@ -1,4 +1,4 @@
name: CI name: ROCm CPU CI
# We test all supported Python versions as follows: # We test all supported Python versions as follows:
# - 3.10 : Documentation build # - 3.10 : Documentation build
@ -11,10 +11,10 @@ on:
# but only for the main branch # but only for the main branch
push: push:
branches: branches:
- main - rocm-main
pull_request: pull_request:
branches: branches:
- main - rocm-main
permissions: permissions:
contents: read # to fetch code contents: read # to fetch code
@ -42,12 +42,8 @@ jobs:
- run: pre-commit run --show-diff-on-failure --color=always --all-files - run: pre-commit run --show-diff-on-failure --color=always --all-files
build: build:
# Don't execute in fork due to runner type
if: github.repository == 'jax-ml/jax'
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
runs-on: linux-x86-n2-32 runs-on: ROCM-Ubuntu
container:
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
timeout-minutes: 60 timeout-minutes: 60
strategy: strategy:
matrix: matrix:
@ -65,10 +61,6 @@ jobs:
num_generated_cases: 1 num_generated_cases: 1
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Image Setup
run: |
apt update
apt install -y libssl-dev
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with: with:
@ -109,7 +101,7 @@ jobs:
documentation: documentation:
name: Documentation - test code snippets name: Documentation - test code snippets
runs-on: ubuntu-latest runs-on: ROCM-Ubuntu
timeout-minutes: 10 timeout-minutes: 10
strategy: strategy:
matrix: matrix:
@ -146,19 +138,13 @@ jobs:
documentation_render: documentation_render:
name: Documentation - render documentation name: Documentation - render documentation
runs-on: linux-x86-n2-16 runs-on: ubuntu-latest
container: timeout-minutes: 20
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
timeout-minutes: 10
strategy: strategy:
matrix: matrix:
python-version: ['3.10'] python-version: ['3.10']
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Image Setup
run: |
apt update
apt install -y libssl-dev libsqlite3-dev
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with: with:
@ -229,9 +215,7 @@ jobs:
ffi: ffi:
name: FFI example name: FFI example
runs-on: linux-x86-g2-16-l4-1gpu runs-on: ROCM-Ubuntu
container:
image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -250,7 +234,7 @@ jobs:
path: ${{ steps.pip-cache.outputs.dir }} path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
- name: Install JAX - name: Install JAX
run: pip install .[cuda12] run: pip install .
- name: Build and install example project - name: Build and install example project
run: python -m pip install -v ./examples/ffi[test] run: python -m pip install -v ./examples/ffi[test]
env: env:
@ -259,10 +243,11 @@ jobs:
# a different toolchain. GCC is the default compiler on the # a different toolchain. GCC is the default compiler on the
# 'ubuntu-latest' runner, but we still set this explicitly just to be # 'ubuntu-latest' runner, but we still set this explicitly just to be
# clear. # clear.
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
- name: Run CPU tests - name: Run CPU tests
run: python -m pytest examples/ffi/tests run: python -m pytest examples/ffi/tests
env: env:
JAX_PLATFORM_NAME: cpu JAX_PLATFORM_NAME: cpu
- name: Run GPU tests - name: Run GPU tests
run: python -m pytest examples/ffi/tests run: python -m pytest examples/ffi/tests

63
.github/workflows/rocm-ci.yml vendored Normal file

@ -0,0 +1,63 @@
name: ROCm GPU CI
on:
# Trigger the workflow on push or pull request,
# but only for the rocm-main branch
push:
branches:
- rocm-main
pull_request:
branches:
- rocm-main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
jobs:
build-jax-in-docker: # strategy and matrix come here
runs-on: mi-250
env:
BASE_IMAGE: "ubuntu:22.04"
TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
PYTHON_VERSION: "3.10"
ROCM_VERSION: "6.2.4"
WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
steps:
- name: Clean up old runs
run: |
ls
# Make sure that we own all of the files so that we have permissions to delete them
docker run -v "./:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
# Remove any old work directories from this machine
rm -rf workdir_*
ls
- name: Print system info
run: |
whoami
printenv
df -h
rocm-smi
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: ${{ env.WORKSPACE_DIR }}
- name: Build JAX
run: |
pushd $WORKSPACE_DIR
python3 build/rocm/ci_build \
--rocm-version $ROCM_VERSION \
--base-docker $BASE_IMAGE \
--python-versions $PYTHON_VERSION \
--compiler=clang \
dist_docker \
--image-tag $TEST_IMAGE
- name: Archive jax wheels
uses: actions/upload-artifact@v4
with:
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
path: ./dist/*.whl
- name: Run tests
run: |
cd $WORKSPACE_DIR
python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"

@ -0,0 +1,52 @@
# Pulls the latest changes from upstream into main and opens a PR to merge
# them into rocm-main branch.
name: ROCm Nightly Upstream Sync
on:
workflow_dispatch:
schedule:
- cron: '0 6 * * 1-5'
permissions:
contents: write
pull-requests: write
env:
SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }}
jobs:
sync-main:
runs-on: ubuntu-latest
steps:
- run: |
gh auth status
gh repo sync rocm/jax -b main
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
create-sync-branch:
needs: sync-main
runs-on: ubuntu-latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Create branch
run: |
git fetch
git checkout origin/main
git checkout -b $SYNC_BRANCH_NAME
# Try and merge rocm-main into this new branch so that we don't run upstream's CI code
git config --global user.email "github-actions@github.com"
git config --global user.name "GitHub Actions"
git merge origin/rocm-main || true
# If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts
git merge --abort || true
git push origin HEAD
open-sync-pr:
needs: create-sync-branch
runs-on: ubuntu-latest
steps:
- run: |
gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream"
gh pr merge --repo $GITHUB_REPOSITORY --merge --auto $SYNC_BRANCH_NAME
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

@ -0,0 +1,41 @@
name: ROCm Open Upstream PR
on:
pull_request:
types: [ labeled ]
branches: [ rocm-main ]
jobs:
open-upstream:
if: ${{ github.event.label.name == 'open-upstream' }}
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream"
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Rebase code to main
run: |
git config --global user.email "github-actions@github.com"
git config --global user.name "Github Actions"
git fetch
git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }}
git rebase --onto origin/main origin/rocm-main
# Force push here so that we don't run into conflicts with the origin branch
git push origin HEAD --force
- name: Leave link to create PR
env:
GH_TOKEN: ${{ github.token }}
run: |
# Bash is not friendly with newline characters, so make our own
NL=$'\n'
# Encode the PR title and body for passing as URL get parameters
TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri')
BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri')
# Create a link to the that will open up a new PR form to upstream and autofill the fields
CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC"
# Add a comment with the link to the PR
COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK"
gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY"

@ -22,7 +22,7 @@ on:
jobs: jobs:
upstream-dev: upstream-dev:
runs-on: ubuntu-20.04-16core runs-on: ROCM-Ubuntu
permissions: permissions:
contents: read contents: read
issues: write # for failed-build-issue issues: write # for failed-build-issue

@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
### Simplified Build Script ### Simplified Build Script
For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script. For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.

@ -5,7 +5,11 @@ ARG ROCM_BUILD_JOB
ARG ROCM_BUILD_NUM ARG ROCM_BUILD_NUM
# Install system GCC and C++ libraries. # Install system GCC and C++ libraries.
RUN yum install -y gcc-c++.x86_64 # (charleshofer) This is not ideal, as we should already have GCC and C++ libraries in the
# manylinux base image. However, adding this does fix an issue where Bazel isn't able
# to find them.
RUN --mount=type=cache,target=/var/cache/dnf \
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64
RUN --mount=type=cache,target=/var/cache/dnf \ RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
@ -20,3 +24,6 @@ RUN --mount=type=cache,target=/var/cache/dnf \
RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \ RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \ mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project
# Stop git from erroring out when we don't own the repo
RUN git config --global --add safe.directory '*'

@ -21,11 +21,15 @@
import argparse import argparse
import logging
import os import os
import subprocess import subprocess
import sys import sys
LOG = logging.getLogger("ci_build")
def image_by_name(name): def image_by_name(name):
cmd = ["docker", "images", "-q", "-f", "reference=%s" % name] cmd = ["docker", "images", "-q", "-f", "reference=%s" % name]
out = subprocess.check_output(cmd) out = subprocess.check_output(cmd)
@ -33,6 +37,25 @@ def image_by_name(name):
return image_id return image_id
def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
cmd = [
"docker",
"build",
"-f",
"build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
"--build-arg=ROCM_VERSION=%s" % rocm_version,
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--tag=%s" % image_name,
".",
]
LOG.info("Creating manylinux build image. Running: %s", cmd)
_ = subprocess.run(cmd, check=True)
return image_name
def dist_wheels( def dist_wheels(
rocm_version, rocm_version,
python_versions, python_versions,
@ -41,34 +64,13 @@ def dist_wheels(
rocm_build_num="", rocm_build_num="",
compiler="gcc", compiler="gcc",
): ):
# We want to make sure the wheels we build are manylinux compliant. We'll
# do the build in a container. Build the image for this.
image_name = create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num)
if xla_path: if xla_path:
xla_path = os.path.abspath(xla_path) xla_path = os.path.abspath(xla_path)
# create manylinux image with requested ROCm installed
image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
# Try removing the Docker image.
try:
subprocess.run(["docker", "rmi", image], check=True)
print(f"Image {image} removed successfully.")
except subprocess.CalledProcessError as e:
print(f"Failed to remove Docker image {image}: {e}")
cmd = [
"docker",
"build",
"-f",
"build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
"--build-arg=ROCM_VERSION=%s" % rocm_version,
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--tag=%s" % image,
".",
]
if not image_by_name(image):
_ = subprocess.run(cmd, check=True)
# use image to build JAX/jaxlib wheels # use image to build JAX/jaxlib wheels
os.makedirs("wheelhouse", exist_ok=True) os.makedirs("wheelhouse", exist_ok=True)
@ -114,13 +116,14 @@ def dist_wheels(
[ [
"--init", "--init",
"--rm", "--rm",
image, image_name,
"bash", "bash",
"-c", "-c",
" ".join(bw_cmd), " ".join(bw_cmd),
] ]
) )
LOG.info("Running: %s", cmd)
_ = subprocess.run(cmd, check=True) _ = subprocess.run(cmd, check=True)
@ -141,10 +144,16 @@ def _fetch_jax_metadata(xla_path):
jax_version = subprocess.check_output(cmd, env=env) jax_version = subprocess.check_output(cmd, env=env)
def safe_decode(x):
if isinstance(x, str):
return x
else:
return x.decode("utf8")
return { return {
"jax_version": jax_version.decode("utf8").strip(), "jax_version": safe_decode(jax_version).strip(),
"jax_commit": jax_commit.decode("utf8").strip(), "jax_commit": safe_decode(jax_commit).strip(),
"xla_commit": xla_commit.decode("utf8").strip(), "xla_commit": safe_decode(xla_commit).strip(),
} }
@ -193,7 +202,7 @@ def dist_docker(
subprocess.check_call(cmd) subprocess.check_call(cmd)
def test(image_name): def test(image_name, test_cmd):
"""Run unit tests like CI would inside a JAX image.""" """Run unit tests like CI would inside a JAX image."""
gpu_args = [ gpu_args = [
@ -211,10 +220,12 @@ def test(image_name):
cmd = [ cmd = [
"docker", "docker",
"run", "run",
"-it",
"--rm", "--rm",
] ]
if os.isatty(sys.stdout.fileno()):
cmd.append("-it")
# NOTE(mrodden): we need jax source dir for the unit test code only, # NOTE(mrodden): we need jax source dir for the unit test code only,
# JAX and jaxlib are already installed from wheels # JAX and jaxlib are already installed from wheels
mounts = [ mounts = [
@ -225,7 +236,7 @@ def test(image_name):
cmd.extend(mounts) cmd.extend(mounts)
cmd.extend(gpu_args) cmd.extend(gpu_args)
container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh" container_cmd = "cd /jax && " + test_cmd
cmd.append(image_name) cmd.append(image_name)
cmd.extend( cmd.extend(
[ [
@ -288,6 +299,7 @@ def parse_args():
testp = subp.add_parser("test") testp = subp.add_parser("test")
testp.add_argument("image_name") testp.add_argument("image_name")
testp.add_argument("--test-cmd", default="./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh")
ddp = subp.add_parser("dist_docker") ddp = subp.add_parser("dist_docker")
ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms") ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
@ -298,6 +310,7 @@ def parse_args():
def main(): def main():
logging.basicConfig(level=logging.INFO)
args = parse_args() args = parse_args()
if args.action == "dist_wheels": if args.action == "dist_wheels":
@ -310,7 +323,7 @@ def main():
) )
elif args.action == "test": elif args.action == "test":
test(args.image_name) test(args.image_name, args.test_cmd)
elif args.action == "dist_docker": elif args.action == "dist_docker":
dist_wheels( dist_wheels(

@ -25,179 +25,205 @@ GPU_LOCK = threading.Lock()
LAST_CODE = 0 LAST_CODE = 0
base_dir = "./logs" base_dir = "./logs"
def extract_filename(path): def extract_filename(path):
base_name = os.path.basename(path) base_name = os.path.basename(path)
file_name, _ = os.path.splitext(base_name) file_name, _ = os.path.splitext(base_name)
return file_name return file_name
def combine_json_reports(): def combine_json_reports():
all_json_files = [f for f in os.listdir(base_dir) if f.endswith('_log.json')] all_json_files = [f for f in os.listdir(base_dir) if f.endswith("_log.json")]
combined_data = [] combined_data = []
for json_file in all_json_files: for json_file in all_json_files:
with open(os.path.join(base_dir, json_file), 'r') as infile: with open(os.path.join(base_dir, json_file), "r") as infile:
data = json.load(infile) data = json.load(infile)
combined_data.append(data) combined_data.append(data)
combined_json_file = f"{base_dir}/final_compiled_report.json" combined_json_file = f"{base_dir}/final_compiled_report.json"
with open(combined_json_file, 'w') as outfile: with open(combined_json_file, "w") as outfile:
json.dump(combined_data, outfile, indent=4) json.dump(combined_data, outfile, indent=4)
def combine_csv_reports(): def combine_csv_reports():
all_csv_files = [f for f in os.listdir(base_dir) if f.endswith('_log.csv')] all_csv_files = [f for f in os.listdir(base_dir) if f.endswith("_log.csv")]
combined_csv_file = f"{base_dir}/final_compiled_report.csv" combined_csv_file = f"{base_dir}/final_compiled_report.csv"
with open(combined_csv_file, mode='w', newline='') as outfile: with open(combined_csv_file, mode="w", newline="") as outfile:
csv_writer = csv.writer(outfile) csv_writer = csv.writer(outfile)
for i, csv_file in enumerate(all_csv_files): for i, csv_file in enumerate(all_csv_files):
with open(os.path.join(base_dir, csv_file), mode='r') as infile: with open(os.path.join(base_dir, csv_file), mode="r") as infile:
csv_reader = csv.reader(infile) csv_reader = csv.reader(infile)
if i == 0: if i == 0:
# write headers only once # write headers only once
csv_writer.writerow(next(csv_reader)) csv_writer.writerow(next(csv_reader))
for row in csv_reader: for row in csv_reader:
csv_writer.writerow(row) csv_writer.writerow(row)
def generate_final_report(shell=False, env_vars={}): def generate_final_report(shell=False, env_vars={}):
env = os.environ env = os.environ
env = {**env, **env_vars} env = {**env, **env_vars}
cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html'] cmd = [
result = subprocess.run(cmd, "pytest_html_merger",
shell=shell, "-i",
capture_output=True, f"{base_dir}",
env=env) "-o",
if result.returncode != 0: f"{base_dir}/final_compiled_report.html",
print("FAILED - {}".format(" ".join(cmd))) ]
print(result.stderr.decode()) result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
# Generate json reports. # Generate json reports.
combine_json_reports() combine_json_reports()
# Generate csv reports. # Generate csv reports.
combine_csv_reports() combine_csv_reports()
def run_shell_command(cmd, shell=False, env_vars={}): def run_shell_command(cmd, shell=False, env_vars={}):
env = os.environ env = os.environ
env = {**env, **env_vars} env = {**env, **env_vars}
result = subprocess.run(cmd, result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
shell=shell, if result.returncode != 0:
capture_output=True, print("FAILED - {}".format(" ".join(cmd)))
env=env) print(result.stderr.decode())
if result.returncode != 0:
print("FAILED - {}".format(" ".join(cmd)))
print(result.stderr.decode())
return result.returncode, result.stderr.decode(), result.stdout.decode() return result.returncode, result.stderr.decode(), result.stdout.decode()
def parse_test_log(log_file): def parse_test_log(log_file):
"""Parses the test module log file to extract test modules and functions.""" """Parses the test module log file to extract test modules and functions."""
test_files = set() test_files = set()
with open(log_file, "r") as f: with open(log_file, "r") as f:
for line in f: for line in f:
report = json.loads(line) report = json.loads(line)
if "nodeid" in report: if "nodeid" in report:
module = report["nodeid"].split("::")[0] module = report["nodeid"].split("::")[0]
if module and ".py" in module: if module and ".py" in module:
test_files.add(os.path.abspath(module)) test_files.add(os.path.abspath(module))
return test_files return test_files
def collect_testmodules(): def collect_testmodules():
log_file = f"{base_dir}/collect_module_log.jsonl" log_file = f"{base_dir}/collect_module_log.jsonl"
return_code, stderr, stdout = run_shell_command( return_code, stderr, stdout = run_shell_command(
["python3", "-m", "pytest", "--collect-only", "tests", f"--report-log={log_file}"]) [
if return_code != 0: "python3",
print("Test module discovery failed.") "-m",
print("STDOUT:", stdout) "pytest",
print("STDERR:", stderr) "--collect-only",
exit(return_code) "tests",
print("---------- collected test modules ----------") f"--report-log={log_file}",
test_files = parse_test_log(log_file) ]
print("Found %d test modules." % (len(test_files))) )
print("--------------------------------------------") if return_code != 0:
print("\n".join(test_files)) print("Test module discovery failed.")
return test_files print("STDOUT:", stdout)
print("STDERR:", stderr)
exit(return_code)
print("---------- collected test modules ----------")
test_files = parse_test_log(log_file)
print("Found %d test modules." % (len(test_files)))
print("--------------------------------------------")
print("\n".join(test_files))
return test_files
def run_test(testmodule, gpu_tokens, continue_on_fail): def run_test(testmodule, gpu_tokens, continue_on_fail):
global LAST_CODE global LAST_CODE
with GPU_LOCK: with GPU_LOCK:
if LAST_CODE != 0: if LAST_CODE != 0:
return return
target_gpu = gpu_tokens.pop() target_gpu = gpu_tokens.pop()
env_vars = { env_vars = {
"HIP_VISIBLE_DEVICES": str(target_gpu), "HIP_VISIBLE_DEVICES": str(target_gpu),
"XLA_PYTHON_CLIENT_ALLOCATOR": "default", "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
} }
testfile = extract_filename(testmodule) testfile = extract_filename(testmodule)
if continue_on_fail: if continue_on_fail:
cmd = ["python3", "-m", "pytest", cmd = [
"--json-report", f"--json-report-file={base_dir}/{testfile}_log.json", "python3",
f"--csv={base_dir}/{testfile}_log.csv", "-m",
"--csv-columns", "id,module,name,file,status,duration", "pytest",
f"--html={base_dir}/{testfile}_log.html", "--json-report",
"--reruns", "3", "-v", testmodule] f"--json-report-file={base_dir}/{testfile}_log.json",
else: f"--csv={base_dir}/{testfile}_log.csv",
cmd = ["python3", "-m", "pytest", "--csv-columns",
"--json-report", f"--json-report-file={base_dir}/{testfile}_log.json", "id,module,name,file,status,duration",
f"--csv={base_dir}/{testfile}_log.csv", f"--html={base_dir}/{testfile}_log.html",
"--csv-columns", "id,module,name,file,status,duration", "--reruns",
f"--html={base_dir}/{testfile}_log.html", "3",
"--reruns", "3", "-x", "-v", testmodule] "-v",
testmodule,
]
else:
cmd = [
"python3",
"-m",
"pytest",
"--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns",
"id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html",
"--reruns",
"3",
"-x",
"-v",
testmodule,
]
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars) return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
with GPU_LOCK: with GPU_LOCK:
gpu_tokens.append(target_gpu) gpu_tokens.append(target_gpu)
if LAST_CODE == 0: if LAST_CODE == 0:
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu)) print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
print(stdout) print(stdout)
print(stderr) print(stderr)
if continue_on_fail == False: if continue_on_fail == False:
LAST_CODE = return_code LAST_CODE = return_code
def run_parallel(all_testmodules, p, c): def run_parallel(all_testmodules, p, c):
print(f"Running tests with parallelism = {p}") print(f"Running tests with parallelism = {p}")
available_gpu_tokens = list(range(p)) available_gpu_tokens = list(range(p))
executor = ThreadPoolExecutor(max_workers=p) executor = ThreadPoolExecutor(max_workers=p)
# walking through test modules. # walking through test modules.
for testmodule in all_testmodules: for testmodule in all_testmodules:
executor.submit(run_test, testmodule, available_gpu_tokens, c) executor.submit(run_test, testmodule, available_gpu_tokens, c)
# waiting for all modules to finish. # waiting for all modules to finish.
executor.shutdown(wait=True) executor.shutdown(wait=True)
def find_num_gpus(): def find_num_gpus():
cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"] cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
_, _, stdout = run_shell_command(cmd, shell=True) _, _, stdout = run_shell_command(cmd, shell=True)
return int(stdout) return int(stdout)
def main(args): def main(args):
all_testmodules = collect_testmodules() all_testmodules = collect_testmodules()
run_parallel(all_testmodules, args.parallel, args.continue_on_fail) run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
generate_final_report() generate_final_report()
exit(LAST_CODE) exit(LAST_CODE)
if __name__ == '__main__': if __name__ == "__main__":
os.environ['HSA_TOOLS_LIB'] = "libroctracer64.so" os.environ["HSA_TOOLS_LIB"] = "libroctracer64.so"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-p", parser.add_argument(
"--parallel", "-p", "--parallel", type=int, help="number of tests to run in parallel"
type=int, )
help="number of tests to run in parallel") parser.add_argument(
parser.add_argument("-c", "-c", "--continue_on_fail", action="store_true", help="continue on failure"
"--continue_on_fail", )
action='store_true', args = parser.parse_args()
help="continue on failure") if args.continue_on_fail:
args = parser.parse_args() print("continue on fail is set")
if args.continue_on_fail: if args.parallel is None:
print("continue on fail is set") sys_gpu_count = find_num_gpus()
if args.parallel is None: args.parallel = sys_gpu_count
sys_gpu_count = find_num_gpus() print("%d GPUs detected." % sys_gpu_count)
args.parallel = sys_gpu_count
print("%d GPUs detected." % sys_gpu_count)
main(args) main(args)

@ -116,6 +116,7 @@ def build_jaxlib_wheel(
if compiler == "clang": if compiler == "clang":
clang_path = find_clang_path() clang_path = find_clang_path()
if clang_path: if clang_path:
LOG.info("Found clang at path: %s", clang_path)
cmd.append("--clang_path=%s" % clang_path) cmd.append("--clang_path=%s" % clang_path)
else: else:
raise RuntimeError("Clang binary not found in /usr/lib/llvm-*") raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")
@ -315,6 +316,21 @@ def main():
LOG.info("Copying %s into %s" % (whl, wheelhouse_dir)) LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
shutil.copy(whl, wheelhouse_dir) shutil.copy(whl, wheelhouse_dir)
# Delete the 'dist' directory since it causes permissions issues
logging.info('Deleting dist, egg-info and cache directory')
shutil.rmtree(os.path.join(args.jax_path, "dist"))
shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info"))
shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__"))
# Make the wheels deleteable by the runner
whl_house = os.path.join(args.jax_path, "wheelhouse")
logging.info("Changing permissions for %s" % whl_house)
mode = 0o664
for item in os.listdir(whl_house):
whl_path = os.path.join(whl_house, item)
if os.path.isfile(whl_path):
os.chmod(whl_path, mode)
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)

@ -0,0 +1,53 @@
#!/bin/bash
# Check for user-supplied arguments.
if [[ $# -lt 2 ]]; then
echo "Usage: $0 <jax_home_directory> <version>"
exit 1
fi
# Set JAX_HOME and RELEASE_VERSION from user arguments.
JAX_HOME=$1
RELEASE_VERSION=$2
WHEELHOUSE="$JAX_HOME/wheelhouse"
# Projects to upload separately to PyPI.
PROJECTS=("jax_rocm60_pjrt" "jax_rocm60_plugin")
# PyPI API Token.
PYPI_API_TOKEN=${PYPI_API_TOKEN:-"pypi-replace_with_token"}
# Ensure the specified JAX_HOME and wheelhouse directories exists.
if [[ ! -d "$JAX_HOME" ]]; then
echo "Error: The specified JAX_HOME directory does not exist: $JAX_HOME"
exit 1
fi
if [[ ! -d "$WHEELHOUSE" ]]; then
echo "Error: The wheelhouse directory does not exist: $WHEELHOUSE"
exit 1
fi
upload_and_release_project() {
local project=$1
echo "Searching for wheels matching project: $project version: $RELEASE_VERSION..."
wheels=($(ls $WHEELHOUSE | grep "^${project}-${RELEASE_VERSION}[.-].*\.whl"))
if [[ ${#wheels[@]} -eq 0 ]]; then
echo "No wheels found for project: $project version: $RELEASE_VERSION. Skipping..."
return
fi
echo "Found wheels for $project: ${wheels[*]}"
echo "Uploading wheels for $project version $RELEASE_VERSION to PyPI..."
for wheel in "${wheels[@]}"; do
twine upload --verbose --repository pypi --non-interactive --username "__token__" --password "$PYPI_API_TOKEN" "$WHEELHOUSE/$wheel"
done
}
# Install twine if not already installed.
python -m pip install --upgrade twine
# Iterate over each project and upload its wheels.
for project in "${PROJECTS[@]}"; do
upload_and_release_project $project
done

@ -0,0 +1,65 @@
# ROCm CI Dev Guide
This guide lays out how to do some dev operations, what branches live in this repo, and what CI workflows live in this repo.
# Quick Tips
1. Always use "Squash and Merge" when merging PRs into `rocm-main` (unless you're merging the daily sync from upstream).
2. When submitting a PR to `rocm-main`, make sure that your feature branch started from `rocm-main`. When you started working on your feature, did you do `git checkout rocm-main && git checkout -b <my feature branch>`?
3. Always fill out your PR's description with an explanation of why we need this change to give context to your fellow devs (and your future self).
4. In the PR description, link to the story or GitHub issue that this PR solves.
# Processes
## Making a Change
1. Clone `rocm/jax` and check out the `rocm-main` branch.
2. Create a new feature branch with `git checkout -b <my feature name>`.
3. Make your changes on the feature branch and test them locally.
4. Push your changes to a new feature branch in `rocm/jax` by running
`git push orgin HEAD`.
5. Open a PR from your new feature branch into `rocm-main` with a nice description telling your
team members what the change is for. Bonus points if you can link it to an issue or story.
6. Add reviewers, wait for approval, and make sure CI passes.
7. Depending on if your specific change, either:
a. If this is a normal, run-of-the-mill change that we want to put upstream, add the
`open-upstream` label to your PR and close your PR. In a few minutes, Actions will
comment on your PR with a link that lets you open a new PR into upstream. The link will
autofill some PR info, and the new PR be created on a new branch that has the same name
as your old feature branch, but with the `-upstream` suffix appended to the end of it.
If upstream reviewers request some changes to the new PR before merging, you can add
or modify commits on the new `-upstream` feature branch.
b. If this is an urgent change that we want in `rocm-main` right now but also want upstream,
add the `open-upstream` label, merge your PR, and then follow the link that
c. If this is a change that we only want to keep in `rocm/jax` and not push into upstream,
squash and merge your PR.
If you submitted your PR upstream with `open-upstream`, you should see your change in `rocm-main`
the next time the `ROCm Nightly Upstream Sync` workflow is run and the PR that it creates is
merged.
When using the `open-upstream` label to move changes to upstream, it's best to put the label on the PR when you either close or merge the PR. The GitHub Actions workflow that handles the `open-upstream` label uses `git rebase --onto` to set up the changes destined for upstream. Adding the label and creating this branch long after the PR has been merged or closed can cause merge conflicts with new upstream code and cause the workflow to fail. Adding the label right after creating your PR means that 1) any changes you make to your downstream PR while it is in review won't make it to upstream, and it is up to you to cherry-pick those changes into the upstream branch or remove and re-add the `open-upstream` label to get the Actions workflow to do it for you, and 2) that you're proposing changes to upstream that the rest of the AMD team might still have comments on.
## Daily Upstream Sync
Every day, GitHub Actions will attempt to run the `ROCm Nightly Upstream Sync` workflow. This job
normally does this on its own, but requires a developer to intervene if there's a merge conflict
or if the PR fails CI. Devs should fix or resolve issues with the merge by adding commits to the
PR's branch.
# Branches
* `rocm-main` - the default "trunk" branch for this repo. Should only be changed submitting PRs to it from feature branches created by devs.
* `main` - a copy of `jax-ml/jax:main`. This branch is "read-only" and should only be changed by GitHub Actions.
# CI Workflows
We use GitHub Actions to run tests on PRs and to automate some of our
development tasks. These all live in `.github/workflows`.
| Name | File | Trigger | Description |
|----------------------------|----------------------------------|------------------------------------------------------|----------------------------------------------------------------------------------------|
| ROCm GPU CI | `rocm-ci.yml` | Open or commit changes to a PR targeting `rocm-main` | Builds and runs JAX on ROCm for PRs going into `rocm-main` |
| ROCm Open Upstream PR | `rocm-open-upstream-pr.yml` | Add the `open-upstream` label to a PR | Copies changes from a PR aimed at `rocm-main` into a new PR aimed at upstream's `main` |
| ROCm Nightly Upstream Sync | `rocm-nightly-upstream-sync.yml` | Runs nightly, can be triggered manually via Actions | Opens a PR that merges changes from upstream `main` into our `rocm-main` branch |

@ -450,6 +450,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8) self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product( @jtu.sample_product(
[dict(l_max=l_max, num_z=num_z) [dict(l_max=l_max, num_z=num_z)
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8]) for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])

@ -262,6 +262,7 @@ class PallasCallRemoteDMAInterpretTest(parameterized.TestCase):
@parameterized.parameters(('left',), ('right',)) @parameterized.parameters(('left',), ('right',))
def test_interpret_remote_dma_ppermute(self, permutation): def test_interpret_remote_dma_ppermute(self, permutation):
self.skipTest("ROCm: Skipping for now")
if jax.device_count() <= 1: if jax.device_count() <= 1:
self.skipTest('Test requires multiple devices.') self.skipTest('Test requires multiple devices.')
num_devices = jax.device_count() num_devices = jax.device_count()

@ -21,15 +21,15 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update XLA_SHA256 with the result. # and update XLA_SHA256 with the result.
XLA_COMMIT = "85234bd3d7ad4514d5dd6df76d6683ea13489840" XLA_COMMIT = "87f7f56cb1ca6aa90fee6128774346bfa83c29f6"
XLA_SHA256 = "175391000a4d454d040b890219318e5f6028d2a0f37f60b6b3bf388254a880e0" XLA_SHA256 = "178166e7e0c4cadd2ad0b016ab89cd90380e6ceffde3610f36857a9b659ae255"
def repo(): def repo():
tf_http_archive( tf_http_archive(
name = "xla", name = "xla",
sha256 = XLA_SHA256, sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), urls = tf_mirror_urls("https://github.com/rocm/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
) )
# For development, one often wants to make changes to the TF repository as well # For development, one often wants to make changes to the TF repository as well