Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-144_1

This commit is contained in:
GitHub Actions 2025-03-12 16:57:18 +00:00
commit a0edd3fbb2
22 changed files with 368 additions and 139 deletions

View File

@ -130,6 +130,8 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
build:clang --copt=-Qunused-arguments
# Error on struct/class mismatches, since this causes link failures on Windows.
build:clang --copt=-Werror=mismatched-tags
# Don't error out on C++23 extensions. Needed for building the clang-19.
build:clang --copt=-Wno-error=c23-extensions
# Configs for CUDA
build:cuda --repo_env TF_NEED_CUDA=1

3
.github/CODEOWNERS vendored Normal file
View File

@ -0,0 +1,3 @@
# Require approvals from someone on the JAX team before PRs are merged
* @ROCm/jax-devs

View File

@ -1,4 +1,4 @@
name: CI
name: ROCm CPU CI
# We test all supported Python versions as follows:
# - 3.10 : Documentation build
@ -11,10 +11,10 @@ on:
# but only for the main branch
push:
branches:
- main
- rocm-main
pull_request:
branches:
- main
- rocm-main
permissions:
contents: read # to fetch code
@ -42,12 +42,8 @@ jobs:
- run: pre-commit run --show-diff-on-failure --color=always --all-files
build:
# Don't execute in fork due to runner type
if: github.repository == 'jax-ml/jax'
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
runs-on: linux-x86-n2-32
container:
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
runs-on: ROCM-Ubuntu
timeout-minutes: 60
strategy:
matrix:
@ -65,10 +61,6 @@ jobs:
num_generated_cases: 1
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Image Setup
run: |
apt update
apt install -y libssl-dev
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
@ -95,12 +87,12 @@ jobs:
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short --maxfail=20 tests examples
pytest -n 4 --tb=short --maxfail=20 tests examples
documentation:
name: Documentation - test code snippets
runs-on: ubuntu-latest
runs-on: ROCM-Ubuntu
timeout-minutes: 10
strategy:
matrix:
@ -128,19 +120,13 @@ jobs:
documentation_render:
name: Documentation - render documentation
runs-on: linux-x86-n2-16
container:
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
timeout-minutes: 10
runs-on: ubuntu-latest
timeout-minutes: 20
strategy:
matrix:
python-version: ['3.10']
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Image Setup
run: |
apt update
apt install -y libssl-dev libsqlite3-dev
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with:
@ -194,9 +180,7 @@ jobs:
ffi:
name: FFI example
runs-on: linux-x86-g2-16-l4-1gpu
container:
image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
runs-on: ROCM-Ubuntu
timeout-minutes: 30
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -207,7 +191,8 @@ jobs:
- name: Install JAX
run: |
pip install uv~=0.5.30
uv pip install --system .[cuda12]
pip install uv
uv pip install --system .
- name: Build and install example project
run: uv pip install --system ./examples/ffi[test]
env:
@ -216,10 +201,11 @@ jobs:
# a different toolchain. GCC is the default compiler on the
# 'ubuntu-latest' runner, but we still set this explicitly just to be
# clear.
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
- name: Run CPU tests
run: python -m pytest examples/ffi/tests
env:
JAX_PLATFORM_NAME: cpu
- name: Run GPU tests
run: python -m pytest examples/ffi/tests

View File

@ -1,23 +1,27 @@
name: ROCm GPU Post-Merge Check
name: ROCm GPU CI
on:
# Trigger the workflow after a push into the main branch
# Trigger the workflow on push or pull request,
# but only for the rocm-main branch
push:
branches:
- main
permissions:
contents: read
- rocm-main
- 'rocm-jaxlib-v*'
pull_request:
branches:
- rocm-main
- 'rocm-jaxlib-v*'
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
jobs:
build-jax-in-docker:
runs-on: linux-x86_64-cirrascale-64-8gpu-amd-mi250
build-jax-in-docker: # strategy and matrix come here
runs-on: mi-250
env:
BASE_IMAGE: "ubuntu:22.04"
TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
PYTHON_VERSION: "3.10"
ROCM_VERSION: "6.2.4"
WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
@ -32,6 +36,9 @@ jobs:
ls
- name: Print system info
run: |
whoami
printenv
df -h
rocm-smi
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
@ -50,9 +57,11 @@ jobs:
uses: actions/upload-artifact@v4
with:
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
path: ${{ env.WORKSPACE_DIR }}/dist/*.whl
retention-days: 2
path: ./dist/*.whl
- name: Run tests
env:
GPU_COUNT: "8"
GFX: "gfx90a"
run: |
cd $WORKSPACE_DIR
python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"

View File

@ -0,0 +1,66 @@
# Pulls the latest changes from upstream into main and opens a PR to merge
# them into rocm-main branch.
name: ROCm Nightly Upstream Sync
on:
workflow_dispatch:
schedule:
- cron: '0 6 * * 1-5'
permissions:
contents: write
pull-requests: write
env:
SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }}
jobs:
sync-main:
runs-on: ubuntu-latest
steps:
- name: Generate an app token
id: generate-token
uses: actions/create-github-app-token@v1
with:
app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
- name: Sync our main with upstream main
run: |
gh auth status
gh repo sync rocm/jax -b main
env:
GH_TOKEN: ${{ steps.generate-token.outputs.token }}
create-sync-branch:
needs: sync-main
runs-on: ubuntu-latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Create branch
run: |
git fetch
git checkout origin/main
git checkout -b $SYNC_BRANCH_NAME
# Try and merge rocm-main into this new branch so that we don't run upstream's CI code
git config --global user.email "github-actions@github.com"
git config --global user.name "GitHub Actions"
git merge origin/rocm-main || true
# If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts
git merge --abort || true
git push origin HEAD
open-sync-pr:
needs: create-sync-branch
runs-on: ubuntu-latest
steps:
- name: Generate an app token
id: generate-token
uses: actions/create-github-app-token@v1
with:
app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
- name: Open a PR to rocm-main
run: |
gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream"
gh pr merge --repo $GITHUB_REPOSITORY --merge --auto $SYNC_BRANCH_NAME
env:
GH_TOKEN: ${{ steps.generate-token.outputs.token }}

View File

@ -0,0 +1,41 @@
name: ROCm Open Upstream PR
on:
pull_request:
types: [ labeled ]
branches: [ rocm-main ]
jobs:
open-upstream:
if: ${{ github.event.label.name == 'open-upstream' }}
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream"
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Rebase code to main
run: |
git config --global user.email "github-actions@github.com"
git config --global user.name "Github Actions"
git fetch
git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }}
git rebase --onto origin/main origin/rocm-main
# Force push here so that we don't run into conflicts with the origin branch
git push origin HEAD --force
- name: Leave link to create PR
env:
GH_TOKEN: ${{ github.token }}
run: |
# Bash is not friendly with newline characters, so make our own
NL=$'\n'
# Encode the PR title and body for passing as URL get parameters
TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri')
BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri')
# Create a link to the that will open up a new PR form to upstream and autofill the fields
CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC"
# Add a comment with the link to the PR
COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK"
gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY"

View File

@ -22,7 +22,7 @@ on:
jobs:
upstream-dev:
runs-on: ubuntu-latest
runs-on: ROCM-Ubuntu
permissions:
contents: read
issues: write # for failed-build-issue

View File

@ -262,7 +262,7 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
rocm_group.add_argument(
"--rocm_amdgpu_targets",
type=str,
default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201",
default="gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201",
help="A comma-separated list of ROCm amdgpu targets to support.",
)

View File

@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCm
@ -62,7 +62,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
pytest-reportlog \
pytest-rerunfailures \
pytest-json-report \
pytest-csv \
cloudpickle \
portpicker \
matplotlib \

View File

@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
### Simplified Build Script
For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.

View File

@ -9,13 +9,13 @@ ARG ROCM_BUILD_NUM
# manylinux base image. However, adding this does fix an issue where Bazel isn't able
# to find them.
RUN --mount=type=cache,target=/var/cache/dnf \
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel
RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS}
# Install LLVM 18 and dependencies.

View File

@ -127,6 +127,31 @@ def dist_wheels(
]
)
# Add command for unit tests
cmd.extend(
[
"&&",
"bazel",
"test",
"-k",
"--jobs=4",
"--test_verbose_timeout_warnings=true",
"--test_output=all",
"--test_summary=detailed",
"--local_test_jobs=1",
"--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
"--test_env=JAX_SKIP_SLOW_TESTS=0",
"--verbose_failures=true",
"--config=rocm",
"--action_env=ROCM_PATH=/opt/rocm",
"--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
"--test_tag_filters=-multiaccelerator",
"--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
"--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
"//tests:gpu_tests",
]
)
LOG.info("Running: %s", cmd)
_ = subprocess.run(cmd, check=True)
@ -356,3 +381,4 @@ def main():
if __name__ == "__main__":
main()

View File

@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCM
@ -47,7 +47,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
pytest-reportlog \
pytest-rerunfailures \
pytest-json-report \
pytest-csv \
cloudpickle \
portpicker \
matplotlib \

View File

@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCM
@ -46,7 +46,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
pytest-reportlog \
pytest-rerunfailures \
pytest-json-report \
pytest-csv \
cloudpickle \
portpicker \
matplotlib \

View File

@ -54,11 +54,15 @@ run_tests() {
python3 -m pytest \
--html="${LOG_DIR}/multi_gpu_pmap_test_log.html" \
--json-report \
--json-report-file="${LOG_DIR}/multi_gpu_pmap_test_log.json" \
--reruns 3 \
tests/pmap_test.py
python3 -m pytest \
--html="${LOG_DIR}/multi_gpu_multi_device_test_log.html" \
--json-report \
--json-report-file="${LOG_DIR}/multi_gpu_multi_device_test_log.json" \
--reruns 3 \
tests/multi_device_test.py

View File

@ -14,7 +14,6 @@
# limitations under the License.
import os
import csv
import json
import argparse
import threading
@ -43,22 +42,6 @@ def combine_json_reports():
with open(combined_json_file, "w") as outfile:
json.dump(combined_data, outfile, indent=4)
def combine_csv_reports():
all_csv_files = [f for f in os.listdir(base_dir) if f.endswith("_log.csv")]
combined_csv_file = f"{base_dir}/final_compiled_report.csv"
with open(combined_csv_file, mode="w", newline="") as outfile:
csv_writer = csv.writer(outfile)
for i, csv_file in enumerate(all_csv_files):
with open(os.path.join(base_dir, csv_file), mode="r") as infile:
csv_reader = csv.reader(infile)
if i == 0:
# write headers only once
csv_writer.writerow(next(csv_reader))
for row in csv_reader:
csv_writer.writerow(row)
def generate_final_report(shell=False, env_vars={}):
env = os.environ
env = {**env, **env_vars}
@ -76,8 +59,6 @@ def generate_final_report(shell=False, env_vars={}):
# Generate json reports.
combine_json_reports()
# Generate csv reports.
combine_csv_reports()
def run_shell_command(cmd, shell=False, env_vars={}):
@ -147,9 +128,6 @@ def run_test(testmodule, gpu_tokens, continue_on_fail):
"pytest",
"--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns",
"id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html",
"--reruns",
"3",
@ -163,9 +141,6 @@ def run_test(testmodule, gpu_tokens, continue_on_fail):
"pytest",
"--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns",
"id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html",
"--reruns",
"3",

View File

@ -16,77 +16,76 @@ ROCM_BUILD_NUM=main
# Intial release don't have the trialing '.0'
# For example ROCM 5.7.0 is at https://repo.radeon.com/rocm/apt/5.7/
if [ ${ROCM_VERSION##*[^0-9]} -eq '0' ]; then
ROCM_VERS=${ROCM_VERSION%.*}
ROCM_VERS=${ROCM_VERSION%.*}
else
ROCM_VERS=$ROCM_VERSION
ROCM_VERS=$ROCM_VERSION
fi
ROCM_DEB_REPO=${ROCM_DEB_REPO_HOME}${ROCM_VERS}/
if [ ! -f "/${CUSTOM_INSTALL}" ]; then
# Add rocm repository
chmod 1777 /tmp
DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update
DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common
DEBIAN_FRONTEND=noninteractive apt-get clean all
wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -;
if [[ $ROCM_DEB_REPO == https://repo.radeon.com/rocm/* ]] ; then \
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list; \
else \
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list ; \
fi
#Install rocm and other packages
apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
software-properties-common \
clang-6.0 \
clang-format-6.0 \
curl \
g++-multilib \
git \
vim \
libnuma-dev \
virtualenv \
python3-pip \
pciutils \
python-is-python3 \
libffi-dev \
libssl-dev \
build-essential \
zlib1g-dev \
libbz2-dev \
libreadline-dev \
libsqlite3-dev curl \
libncursesw5-dev \
xz-utils \
tk-dev \
libxml2-dev \
libxmlsec1-dev \
libffi-dev \
liblzma-dev \
wget \
rocm-dev \
rocm-libs \
miopen-hip \
miopen-hip-dev \
rocblas \
rocblas-dev \
rocsolver-dev \
rocrand-dev \
rocfft-dev \
hipfft-dev \
hipblas-dev \
rocprim-dev \
hipcub-dev \
rccl-dev \
hipsparse-dev \
hipsolver-dev \
wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Add rocm repository
chmod 1777 /tmp
DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update
DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common
DEBIAN_FRONTEND=noninteractive apt-get clean all
wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
if [[ $ROCM_DEB_REPO == https://repo.radeon.com/rocm/* ]]; then
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" >/etc/apt/sources.list.d/rocm.list
else
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" >/etc/apt/sources.list.d/rocm.list
fi
#Install rocm and other packages
apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
software-properties-common \
clang-6.0 \
clang-format-6.0 \
curl \
g++-multilib \
git \
vim \
libnuma-dev \
virtualenv \
python3-pip \
pciutils \
python-is-python3 \
libffi-dev \
libssl-dev \
build-essential \
zlib1g-dev \
libbz2-dev \
libreadline-dev \
libsqlite3-dev curl \
libncursesw5-dev \
xz-utils \
tk-dev \
libxml2-dev \
libxmlsec1-dev \
libffi-dev \
liblzma-dev \
wget \
rocm-dev \
rocm-libs \
miopen-hip \
miopen-hip-dev \
rocblas \
rocblas-dev \
rocsolver-dev \
rocrand-dev \
rocfft-dev \
hipfft-dev \
hipblas-dev \
rocprim-dev \
hipcub-dev \
rccl-dev \
hipsparse-dev \
hipsolver-dev \
wget &&
apt-get clean &&
rm -rf /var/lib/apt/lists/*
else
bash "/${CUSTOM_INSTALL}"
bash "/${CUSTOM_INSTALL}"
fi
echo $ROCM_VERSION
@ -95,6 +94,6 @@ echo $ROCM_PATH
echo $GPU_DEVICE_TARGETS
# Ensure the ROCm target list is set up
GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS:-"gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"}
GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS:-"gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"}
printf '%s\n' ${GPU_DEVICE_TARGETS} | tee -a "$ROCM_PATH/bin/target.lst"
touch "${ROCM_PATH}/.info/version"

View File

@ -35,7 +35,7 @@ import sys
LOG = logging.getLogger(__name__)
GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
def build_rocm_path(rocm_version_str):

View File

@ -0,0 +1,53 @@
#!/bin/bash
# Check for user-supplied arguments.
if [[ $# -lt 2 ]]; then
echo "Usage: $0 <jax_home_directory> <version>"
exit 1
fi
# Set JAX_HOME and RELEASE_VERSION from user arguments.
JAX_HOME=$1
RELEASE_VERSION=$2
WHEELHOUSE="$JAX_HOME/wheelhouse"
# Projects to upload separately to PyPI.
PROJECTS=("jax_rocm60_pjrt" "jax_rocm60_plugin")
# PyPI API Token.
PYPI_API_TOKEN=${PYPI_API_TOKEN:-"pypi-replace_with_token"}
# Ensure the specified JAX_HOME and wheelhouse directories exists.
if [[ ! -d "$JAX_HOME" ]]; then
echo "Error: The specified JAX_HOME directory does not exist: $JAX_HOME"
exit 1
fi
if [[ ! -d "$WHEELHOUSE" ]]; then
echo "Error: The wheelhouse directory does not exist: $WHEELHOUSE"
exit 1
fi
upload_and_release_project() {
local project=$1
echo "Searching for wheels matching project: $project version: $RELEASE_VERSION..."
wheels=($(ls $WHEELHOUSE | grep "^${project}-${RELEASE_VERSION}[.-].*\.whl"))
if [[ ${#wheels[@]} -eq 0 ]]; then
echo "No wheels found for project: $project version: $RELEASE_VERSION. Skipping..."
return
fi
echo "Found wheels for $project: ${wheels[*]}"
echo "Uploading wheels for $project version $RELEASE_VERSION to PyPI..."
for wheel in "${wheels[@]}"; do
twine upload --verbose --repository pypi --non-interactive --username "__token__" --password "$PYPI_API_TOKEN" "$WHEELHOUSE/$wheel"
done
}
# Install twine if not already installed.
python -m pip install --upgrade twine
# Iterate over each project and upload its wheels.
for project in "${PROJECTS[@]}"; do
upload_and_release_project $project
done

View File

@ -0,0 +1,65 @@
# ROCm CI Dev Guide
This guide lays out how to do some dev operations, what branches live in this repo, and what CI workflows live in this repo.
# Quick Tips
1. Always use "Squash and Merge" when merging PRs into `rocm-main` (unless you're merging the daily sync from upstream).
2. When submitting a PR to `rocm-main`, make sure that your feature branch started from `rocm-main`. When you started working on your feature, did you do `git checkout rocm-main && git checkout -b <my feature branch>`?
3. Always fill out your PR's description with an explanation of why we need this change to give context to your fellow devs (and your future self).
4. In the PR description, link to the story or GitHub issue that this PR solves.
# Processes
## Making a Change
1. Clone `rocm/jax` and check out the `rocm-main` branch.
2. Create a new feature branch with `git checkout -b <my feature name>`.
3. Make your changes on the feature branch and test them locally.
4. Push your changes to a new feature branch in `rocm/jax` by running
`git push orgin HEAD`.
5. Open a PR from your new feature branch into `rocm-main` with a nice description telling your
team members what the change is for. Bonus points if you can link it to an issue or story.
6. Add reviewers, wait for approval, and make sure CI passes.
7. Depending on if your specific change, either:
a. If this is a normal, run-of-the-mill change that we want to put upstream, add the
`open-upstream` label to your PR and close your PR. In a few minutes, Actions will
comment on your PR with a link that lets you open a new PR into upstream. The link will
autofill some PR info, and the new PR be created on a new branch that has the same name
as your old feature branch, but with the `-upstream` suffix appended to the end of it.
If upstream reviewers request some changes to the new PR before merging, you can add
or modify commits on the new `-upstream` feature branch.
b. If this is an urgent change that we want in `rocm-main` right now but also want upstream,
add the `open-upstream` label, merge your PR, and then follow the link that
c. If this is a change that we only want to keep in `rocm/jax` and not push into upstream,
squash and merge your PR.
If you submitted your PR upstream with `open-upstream`, you should see your change in `rocm-main`
the next time the `ROCm Nightly Upstream Sync` workflow is run and the PR that it creates is
merged.
When using the `open-upstream` label to move changes to upstream, it's best to put the label on the PR when you either close or merge the PR. The GitHub Actions workflow that handles the `open-upstream` label uses `git rebase --onto` to set up the changes destined for upstream. Adding the label and creating this branch long after the PR has been merged or closed can cause merge conflicts with new upstream code and cause the workflow to fail. Adding the label right after creating your PR means that 1) any changes you make to your downstream PR while it is in review won't make it to upstream, and it is up to you to cherry-pick those changes into the upstream branch or remove and re-add the `open-upstream` label to get the Actions workflow to do it for you, and 2) that you're proposing changes to upstream that the rest of the AMD team might still have comments on.
## Daily Upstream Sync
Every day, GitHub Actions will attempt to run the `ROCm Nightly Upstream Sync` workflow. This job
normally does this on its own, but requires a developer to intervene if there's a merge conflict
or if the PR fails CI. Devs should fix or resolve issues with the merge by adding commits to the
PR's branch.
# Branches
* `rocm-main` - the default "trunk" branch for this repo. Should only be changed submitting PRs to it from feature branches created by devs.
* `main` - a copy of `jax-ml/jax:main`. This branch is "read-only" and should only be changed by GitHub Actions.
# CI Workflows
We use GitHub Actions to run tests on PRs and to automate some of our
development tasks. These all live in `.github/workflows`.
| Name | File | Trigger | Description |
|----------------------------|----------------------------------|------------------------------------------------------|----------------------------------------------------------------------------------------|
| ROCm GPU CI | `rocm-ci.yml` | Open or commit changes to a PR targeting `rocm-main` | Builds and runs JAX on ROCm for PRs going into `rocm-main` |
| ROCm Open Upstream PR | `rocm-open-upstream-pr.yml` | Add the `open-upstream` label to a PR | Copies changes from a PR aimed at `rocm-main` into a new PR aimed at upstream's `main` |
| ROCm Nightly Upstream Sync | `rocm-nightly-upstream-sync.yml` | Runs nightly, can be triggered manually via Actions | Opens a PR that merges changes from upstream `main` into our `rocm-main` branch |

View File

@ -450,6 +450,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product(
[dict(l_max=l_max, num_z=num_z)
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])

View File

@ -262,6 +262,7 @@ class PallasCallRemoteDMAInterpretTest(parameterized.TestCase):
@parameterized.parameters(('left',), ('right',))
def test_interpret_remote_dma_ppermute(self, permutation):
self.skipTest("ROCm: Skipping for now")
if jax.device_count() <= 1:
self.skipTest('Test requires multiple devices.')
num_devices = jax.device_count()