Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-144_1

This commit is contained in:
GitHub Actions 2025-03-12 16:57:18 +00:00
commit a0edd3fbb2
22 changed files with 368 additions and 139 deletions

View File

@ -130,6 +130,8 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
build:clang --copt=-Qunused-arguments build:clang --copt=-Qunused-arguments
# Error on struct/class mismatches, since this causes link failures on Windows. # Error on struct/class mismatches, since this causes link failures on Windows.
build:clang --copt=-Werror=mismatched-tags build:clang --copt=-Werror=mismatched-tags
# Don't error out on C++23 extensions. Needed for building the clang-19.
build:clang --copt=-Wno-error=c23-extensions
# Configs for CUDA # Configs for CUDA
build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NEED_CUDA=1

3
.github/CODEOWNERS vendored Normal file
View File

@ -0,0 +1,3 @@
# Require approvals from someone on the JAX team before PRs are merged
* @ROCm/jax-devs

View File

@ -1,4 +1,4 @@
name: CI name: ROCm CPU CI
# We test all supported Python versions as follows: # We test all supported Python versions as follows:
# - 3.10 : Documentation build # - 3.10 : Documentation build
@ -11,10 +11,10 @@ on:
# but only for the main branch # but only for the main branch
push: push:
branches: branches:
- main - rocm-main
pull_request: pull_request:
branches: branches:
- main - rocm-main
permissions: permissions:
contents: read # to fetch code contents: read # to fetch code
@ -42,12 +42,8 @@ jobs:
- run: pre-commit run --show-diff-on-failure --color=always --all-files - run: pre-commit run --show-diff-on-failure --color=always --all-files
build: build:
# Don't execute in fork due to runner type
if: github.repository == 'jax-ml/jax'
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})" name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
runs-on: linux-x86-n2-32 runs-on: ROCM-Ubuntu
container:
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
timeout-minutes: 60 timeout-minutes: 60
strategy: strategy:
matrix: matrix:
@ -65,10 +61,6 @@ jobs:
num_generated_cases: 1 num_generated_cases: 1
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Image Setup
run: |
apt update
apt install -y libssl-dev
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with: with:
@ -95,12 +87,12 @@ jobs:
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE" echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short --maxfail=20 tests examples pytest -n 4 --tb=short --maxfail=20 tests examples
documentation: documentation:
name: Documentation - test code snippets name: Documentation - test code snippets
runs-on: ubuntu-latest runs-on: ROCM-Ubuntu
timeout-minutes: 10 timeout-minutes: 10
strategy: strategy:
matrix: matrix:
@ -128,19 +120,13 @@ jobs:
documentation_render: documentation_render:
name: Documentation - render documentation name: Documentation - render documentation
runs-on: linux-x86-n2-16 runs-on: ubuntu-latest
container: timeout-minutes: 20
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
timeout-minutes: 10
strategy: strategy:
matrix: matrix:
python-version: ['3.10'] python-version: ['3.10']
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Image Setup
run: |
apt update
apt install -y libssl-dev libsqlite3-dev
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
with: with:
@ -194,9 +180,7 @@ jobs:
ffi: ffi:
name: FFI example name: FFI example
runs-on: linux-x86-g2-16-l4-1gpu runs-on: ROCM-Ubuntu
container:
image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@ -207,7 +191,8 @@ jobs:
- name: Install JAX - name: Install JAX
run: | run: |
pip install uv~=0.5.30 pip install uv~=0.5.30
uv pip install --system .[cuda12] pip install uv
uv pip install --system .
- name: Build and install example project - name: Build and install example project
run: uv pip install --system ./examples/ffi[test] run: uv pip install --system ./examples/ffi[test]
env: env:
@ -216,10 +201,11 @@ jobs:
# a different toolchain. GCC is the default compiler on the # a different toolchain. GCC is the default compiler on the
# 'ubuntu-latest' runner, but we still set this explicitly just to be # 'ubuntu-latest' runner, but we still set this explicitly just to be
# clear. # clear.
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
- name: Run CPU tests - name: Run CPU tests
run: python -m pytest examples/ffi/tests run: python -m pytest examples/ffi/tests
env: env:
JAX_PLATFORM_NAME: cpu JAX_PLATFORM_NAME: cpu
- name: Run GPU tests - name: Run GPU tests
run: python -m pytest examples/ffi/tests run: python -m pytest examples/ffi/tests

View File

@ -1,23 +1,27 @@
name: ROCm GPU Post-Merge Check name: ROCm GPU CI
on: on:
# Trigger the workflow after a push into the main branch # Trigger the workflow on push or pull request,
# but only for the rocm-main branch
push: push:
branches: branches:
- main - rocm-main
- 'rocm-jaxlib-v*'
permissions: pull_request:
contents: read branches:
- rocm-main
- 'rocm-jaxlib-v*'
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true
jobs: jobs:
build-jax-in-docker: build-jax-in-docker: # strategy and matrix come here
runs-on: linux-x86_64-cirrascale-64-8gpu-amd-mi250 runs-on: mi-250
env: env:
BASE_IMAGE: "ubuntu:22.04" BASE_IMAGE: "ubuntu:22.04"
TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
PYTHON_VERSION: "3.10" PYTHON_VERSION: "3.10"
ROCM_VERSION: "6.2.4" ROCM_VERSION: "6.2.4"
WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }} WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
@ -32,6 +36,9 @@ jobs:
ls ls
- name: Print system info - name: Print system info
run: | run: |
whoami
printenv
df -h
rocm-smi rocm-smi
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with: with:
@ -50,9 +57,11 @@ jobs:
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }} name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
path: ${{ env.WORKSPACE_DIR }}/dist/*.whl path: ./dist/*.whl
retention-days: 2
- name: Run tests - name: Run tests
env:
GPU_COUNT: "8"
GFX: "gfx90a"
run: | run: |
cd $WORKSPACE_DIR cd $WORKSPACE_DIR
python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py" python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"

View File

@ -0,0 +1,66 @@
# Pulls the latest changes from upstream into main and opens a PR to merge
# them into rocm-main branch.
name: ROCm Nightly Upstream Sync
on:
workflow_dispatch:
schedule:
- cron: '0 6 * * 1-5'
permissions:
contents: write
pull-requests: write
env:
SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }}
jobs:
sync-main:
runs-on: ubuntu-latest
steps:
- name: Generate an app token
id: generate-token
uses: actions/create-github-app-token@v1
with:
app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
- name: Sync our main with upstream main
run: |
gh auth status
gh repo sync rocm/jax -b main
env:
GH_TOKEN: ${{ steps.generate-token.outputs.token }}
create-sync-branch:
needs: sync-main
runs-on: ubuntu-latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Create branch
run: |
git fetch
git checkout origin/main
git checkout -b $SYNC_BRANCH_NAME
# Try and merge rocm-main into this new branch so that we don't run upstream's CI code
git config --global user.email "github-actions@github.com"
git config --global user.name "GitHub Actions"
git merge origin/rocm-main || true
# If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts
git merge --abort || true
git push origin HEAD
open-sync-pr:
needs: create-sync-branch
runs-on: ubuntu-latest
steps:
- name: Generate an app token
id: generate-token
uses: actions/create-github-app-token@v1
with:
app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
- name: Open a PR to rocm-main
run: |
gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream"
gh pr merge --repo $GITHUB_REPOSITORY --merge --auto $SYNC_BRANCH_NAME
env:
GH_TOKEN: ${{ steps.generate-token.outputs.token }}

View File

@ -0,0 +1,41 @@
name: ROCm Open Upstream PR
on:
pull_request:
types: [ labeled ]
branches: [ rocm-main ]
jobs:
open-upstream:
if: ${{ github.event.label.name == 'open-upstream' }}
permissions:
contents: write
pull-requests: write
runs-on: ubuntu-latest
env:
NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream"
steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Rebase code to main
run: |
git config --global user.email "github-actions@github.com"
git config --global user.name "Github Actions"
git fetch
git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }}
git rebase --onto origin/main origin/rocm-main
# Force push here so that we don't run into conflicts with the origin branch
git push origin HEAD --force
- name: Leave link to create PR
env:
GH_TOKEN: ${{ github.token }}
run: |
# Bash is not friendly with newline characters, so make our own
NL=$'\n'
# Encode the PR title and body for passing as URL get parameters
TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri')
BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri')
# Create a link to the that will open up a new PR form to upstream and autofill the fields
CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC"
# Add a comment with the link to the PR
COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK"
gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY"

View File

@ -22,7 +22,7 @@ on:
jobs: jobs:
upstream-dev: upstream-dev:
runs-on: ubuntu-latest runs-on: ROCM-Ubuntu
permissions: permissions:
contents: read contents: read
issues: write # for failed-build-issue issues: write # for failed-build-issue

View File

@ -262,7 +262,7 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
rocm_group.add_argument( rocm_group.add_argument(
"--rocm_amdgpu_targets", "--rocm_amdgpu_targets",
type=str, type=str,
default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201", default="gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201",
help="A comma-separated list of ROCm amdgpu targets to support.", help="A comma-separated list of ROCm amdgpu targets to support.",
) )

View File

@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for # Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201" ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCm # Install ROCm
@ -62,7 +62,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
pytest-reportlog \ pytest-reportlog \
pytest-rerunfailures \ pytest-rerunfailures \
pytest-json-report \ pytest-json-report \
pytest-csv \
cloudpickle \ cloudpickle \
portpicker \ portpicker \
matplotlib \ matplotlib \

View File

@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
### Simplified Build Script ### Simplified Build Script
For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script. For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.

View File

@ -9,13 +9,13 @@ ARG ROCM_BUILD_NUM
# manylinux base image. However, adding this does fix an issue where Bazel isn't able # manylinux base image. However, adding this does fix an issue where Bazel isn't able
# to find them. # to find them.
RUN --mount=type=cache,target=/var/cache/dnf \ RUN --mount=type=cache,target=/var/cache/dnf \
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel
RUN --mount=type=cache,target=/var/cache/dnf \ RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201" ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS} RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS}
# Install LLVM 18 and dependencies. # Install LLVM 18 and dependencies.

View File

@ -127,6 +127,31 @@ def dist_wheels(
] ]
) )
# Add command for unit tests
cmd.extend(
[
"&&",
"bazel",
"test",
"-k",
"--jobs=4",
"--test_verbose_timeout_warnings=true",
"--test_output=all",
"--test_summary=detailed",
"--local_test_jobs=1",
"--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
"--test_env=JAX_SKIP_SLOW_TESTS=0",
"--verbose_failures=true",
"--config=rocm",
"--action_env=ROCM_PATH=/opt/rocm",
"--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
"--test_tag_filters=-multiaccelerator",
"--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
"--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
"//tests:gpu_tests",
]
)
LOG.info("Running: %s", cmd) LOG.info("Running: %s", cmd)
_ = subprocess.run(cmd, check=True) _ = subprocess.run(cmd, check=True)
@ -356,3 +381,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for # Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201" ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCM # Install ROCM
@ -47,7 +47,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
pytest-reportlog \ pytest-reportlog \
pytest-rerunfailures \ pytest-rerunfailures \
pytest-json-report \ pytest-json-report \
pytest-csv \
cloudpickle \ cloudpickle \
portpicker \ portpicker \
matplotlib \ matplotlib \

View File

@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for # Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201" ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCM # Install ROCM
@ -46,7 +46,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
pytest-reportlog \ pytest-reportlog \
pytest-rerunfailures \ pytest-rerunfailures \
pytest-json-report \ pytest-json-report \
pytest-csv \
cloudpickle \ cloudpickle \
portpicker \ portpicker \
matplotlib \ matplotlib \

View File

@ -54,11 +54,15 @@ run_tests() {
python3 -m pytest \ python3 -m pytest \
--html="${LOG_DIR}/multi_gpu_pmap_test_log.html" \ --html="${LOG_DIR}/multi_gpu_pmap_test_log.html" \
--json-report \
--json-report-file="${LOG_DIR}/multi_gpu_pmap_test_log.json" \
--reruns 3 \ --reruns 3 \
tests/pmap_test.py tests/pmap_test.py
python3 -m pytest \ python3 -m pytest \
--html="${LOG_DIR}/multi_gpu_multi_device_test_log.html" \ --html="${LOG_DIR}/multi_gpu_multi_device_test_log.html" \
--json-report \
--json-report-file="${LOG_DIR}/multi_gpu_multi_device_test_log.json" \
--reruns 3 \ --reruns 3 \
tests/multi_device_test.py tests/multi_device_test.py

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import os import os
import csv
import json import json
import argparse import argparse
import threading import threading
@ -43,22 +42,6 @@ def combine_json_reports():
with open(combined_json_file, "w") as outfile: with open(combined_json_file, "w") as outfile:
json.dump(combined_data, outfile, indent=4) json.dump(combined_data, outfile, indent=4)
def combine_csv_reports():
all_csv_files = [f for f in os.listdir(base_dir) if f.endswith("_log.csv")]
combined_csv_file = f"{base_dir}/final_compiled_report.csv"
with open(combined_csv_file, mode="w", newline="") as outfile:
csv_writer = csv.writer(outfile)
for i, csv_file in enumerate(all_csv_files):
with open(os.path.join(base_dir, csv_file), mode="r") as infile:
csv_reader = csv.reader(infile)
if i == 0:
# write headers only once
csv_writer.writerow(next(csv_reader))
for row in csv_reader:
csv_writer.writerow(row)
def generate_final_report(shell=False, env_vars={}): def generate_final_report(shell=False, env_vars={}):
env = os.environ env = os.environ
env = {**env, **env_vars} env = {**env, **env_vars}
@ -76,8 +59,6 @@ def generate_final_report(shell=False, env_vars={}):
# Generate json reports. # Generate json reports.
combine_json_reports() combine_json_reports()
# Generate csv reports.
combine_csv_reports()
def run_shell_command(cmd, shell=False, env_vars={}): def run_shell_command(cmd, shell=False, env_vars={}):
@ -147,9 +128,6 @@ def run_test(testmodule, gpu_tokens, continue_on_fail):
"pytest", "pytest",
"--json-report", "--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json", f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns",
"id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html", f"--html={base_dir}/{testfile}_log.html",
"--reruns", "--reruns",
"3", "3",
@ -163,9 +141,6 @@ def run_test(testmodule, gpu_tokens, continue_on_fail):
"pytest", "pytest",
"--json-report", "--json-report",
f"--json-report-file={base_dir}/{testfile}_log.json", f"--json-report-file={base_dir}/{testfile}_log.json",
f"--csv={base_dir}/{testfile}_log.csv",
"--csv-columns",
"id,module,name,file,status,duration",
f"--html={base_dir}/{testfile}_log.html", f"--html={base_dir}/{testfile}_log.html",
"--reruns", "--reruns",
"3", "3",

View File

@ -16,77 +16,76 @@ ROCM_BUILD_NUM=main
# Intial release don't have the trialing '.0' # Intial release don't have the trialing '.0'
# For example ROCM 5.7.0 is at https://repo.radeon.com/rocm/apt/5.7/ # For example ROCM 5.7.0 is at https://repo.radeon.com/rocm/apt/5.7/
if [ ${ROCM_VERSION##*[^0-9]} -eq '0' ]; then if [ ${ROCM_VERSION##*[^0-9]} -eq '0' ]; then
ROCM_VERS=${ROCM_VERSION%.*} ROCM_VERS=${ROCM_VERSION%.*}
else else
ROCM_VERS=$ROCM_VERSION ROCM_VERS=$ROCM_VERSION
fi fi
ROCM_DEB_REPO=${ROCM_DEB_REPO_HOME}${ROCM_VERS}/ ROCM_DEB_REPO=${ROCM_DEB_REPO_HOME}${ROCM_VERS}/
if [ ! -f "/${CUSTOM_INSTALL}" ]; then if [ ! -f "/${CUSTOM_INSTALL}" ]; then
# Add rocm repository # Add rocm repository
chmod 1777 /tmp chmod 1777 /tmp
DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update
DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common
DEBIAN_FRONTEND=noninteractive apt-get clean all DEBIAN_FRONTEND=noninteractive apt-get clean all
wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -; wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
if [[ $ROCM_DEB_REPO == https://repo.radeon.com/rocm/* ]] ; then \ if [[ $ROCM_DEB_REPO == https://repo.radeon.com/rocm/* ]]; then
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list; \ echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" >/etc/apt/sources.list.d/rocm.list
else \ else
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list ; \ echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" >/etc/apt/sources.list.d/rocm.list
fi fi
#Install rocm and other packages #Install rocm and other packages
apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \ apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \ build-essential \
software-properties-common \ software-properties-common \
clang-6.0 \ clang-6.0 \
clang-format-6.0 \ clang-format-6.0 \
curl \ curl \
g++-multilib \ g++-multilib \
git \ git \
vim \ vim \
libnuma-dev \ libnuma-dev \
virtualenv \ virtualenv \
python3-pip \ python3-pip \
pciutils \ pciutils \
python-is-python3 \ python-is-python3 \
libffi-dev \ libffi-dev \
libssl-dev \ libssl-dev \
build-essential \ build-essential \
zlib1g-dev \ zlib1g-dev \
libbz2-dev \ libbz2-dev \
libreadline-dev \ libreadline-dev \
libsqlite3-dev curl \ libsqlite3-dev curl \
libncursesw5-dev \ libncursesw5-dev \
xz-utils \ xz-utils \
tk-dev \ tk-dev \
libxml2-dev \ libxml2-dev \
libxmlsec1-dev \ libxmlsec1-dev \
libffi-dev \ libffi-dev \
liblzma-dev \ liblzma-dev \
wget \ wget \
rocm-dev \ rocm-dev \
rocm-libs \ rocm-libs \
miopen-hip \ miopen-hip \
miopen-hip-dev \ miopen-hip-dev \
rocblas \ rocblas \
rocblas-dev \ rocblas-dev \
rocsolver-dev \ rocsolver-dev \
rocrand-dev \ rocrand-dev \
rocfft-dev \ rocfft-dev \
hipfft-dev \ hipfft-dev \
hipblas-dev \ hipblas-dev \
rocprim-dev \ rocprim-dev \
hipcub-dev \ hipcub-dev \
rccl-dev \ rccl-dev \
hipsparse-dev \ hipsparse-dev \
hipsolver-dev \ hipsolver-dev \
wget && \ wget &&
apt-get clean && \ apt-get clean &&
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
else else
bash "/${CUSTOM_INSTALL}" bash "/${CUSTOM_INSTALL}"
fi fi
echo $ROCM_VERSION echo $ROCM_VERSION
@ -95,6 +94,6 @@ echo $ROCM_PATH
echo $GPU_DEVICE_TARGETS echo $GPU_DEVICE_TARGETS
# Ensure the ROCm target list is set up # Ensure the ROCm target list is set up
GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS:-"gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"} GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS:-"gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"}
printf '%s\n' ${GPU_DEVICE_TARGETS} | tee -a "$ROCM_PATH/bin/target.lst" printf '%s\n' ${GPU_DEVICE_TARGETS} | tee -a "$ROCM_PATH/bin/target.lst"
touch "${ROCM_PATH}/.info/version" touch "${ROCM_PATH}/.info/version"

View File

@ -35,7 +35,7 @@ import sys
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201" GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
def build_rocm_path(rocm_version_str): def build_rocm_path(rocm_version_str):

View File

@ -0,0 +1,53 @@
#!/bin/bash
# Check for user-supplied arguments.
if [[ $# -lt 2 ]]; then
echo "Usage: $0 <jax_home_directory> <version>"
exit 1
fi
# Set JAX_HOME and RELEASE_VERSION from user arguments.
JAX_HOME=$1
RELEASE_VERSION=$2
WHEELHOUSE="$JAX_HOME/wheelhouse"
# Projects to upload separately to PyPI.
PROJECTS=("jax_rocm60_pjrt" "jax_rocm60_plugin")
# PyPI API Token.
PYPI_API_TOKEN=${PYPI_API_TOKEN:-"pypi-replace_with_token"}
# Ensure the specified JAX_HOME and wheelhouse directories exists.
if [[ ! -d "$JAX_HOME" ]]; then
echo "Error: The specified JAX_HOME directory does not exist: $JAX_HOME"
exit 1
fi
if [[ ! -d "$WHEELHOUSE" ]]; then
echo "Error: The wheelhouse directory does not exist: $WHEELHOUSE"
exit 1
fi
upload_and_release_project() {
local project=$1
echo "Searching for wheels matching project: $project version: $RELEASE_VERSION..."
wheels=($(ls $WHEELHOUSE | grep "^${project}-${RELEASE_VERSION}[.-].*\.whl"))
if [[ ${#wheels[@]} -eq 0 ]]; then
echo "No wheels found for project: $project version: $RELEASE_VERSION. Skipping..."
return
fi
echo "Found wheels for $project: ${wheels[*]}"
echo "Uploading wheels for $project version $RELEASE_VERSION to PyPI..."
for wheel in "${wheels[@]}"; do
twine upload --verbose --repository pypi --non-interactive --username "__token__" --password "$PYPI_API_TOKEN" "$WHEELHOUSE/$wheel"
done
}
# Install twine if not already installed.
python -m pip install --upgrade twine
# Iterate over each project and upload its wheels.
for project in "${PROJECTS[@]}"; do
upload_and_release_project $project
done

View File

@ -0,0 +1,65 @@
# ROCm CI Dev Guide
This guide lays out how to do some dev operations, what branches live in this repo, and what CI workflows live in this repo.
# Quick Tips
1. Always use "Squash and Merge" when merging PRs into `rocm-main` (unless you're merging the daily sync from upstream).
2. When submitting a PR to `rocm-main`, make sure that your feature branch started from `rocm-main`. When you started working on your feature, did you do `git checkout rocm-main && git checkout -b <my feature branch>`?
3. Always fill out your PR's description with an explanation of why we need this change to give context to your fellow devs (and your future self).
4. In the PR description, link to the story or GitHub issue that this PR solves.
# Processes
## Making a Change
1. Clone `rocm/jax` and check out the `rocm-main` branch.
2. Create a new feature branch with `git checkout -b <my feature name>`.
3. Make your changes on the feature branch and test them locally.
4. Push your changes to a new feature branch in `rocm/jax` by running
`git push orgin HEAD`.
5. Open a PR from your new feature branch into `rocm-main` with a nice description telling your
team members what the change is for. Bonus points if you can link it to an issue or story.
6. Add reviewers, wait for approval, and make sure CI passes.
7. Depending on if your specific change, either:
a. If this is a normal, run-of-the-mill change that we want to put upstream, add the
`open-upstream` label to your PR and close your PR. In a few minutes, Actions will
comment on your PR with a link that lets you open a new PR into upstream. The link will
autofill some PR info, and the new PR be created on a new branch that has the same name
as your old feature branch, but with the `-upstream` suffix appended to the end of it.
If upstream reviewers request some changes to the new PR before merging, you can add
or modify commits on the new `-upstream` feature branch.
b. If this is an urgent change that we want in `rocm-main` right now but also want upstream,
add the `open-upstream` label, merge your PR, and then follow the link that
c. If this is a change that we only want to keep in `rocm/jax` and not push into upstream,
squash and merge your PR.
If you submitted your PR upstream with `open-upstream`, you should see your change in `rocm-main`
the next time the `ROCm Nightly Upstream Sync` workflow is run and the PR that it creates is
merged.
When using the `open-upstream` label to move changes to upstream, it's best to put the label on the PR when you either close or merge the PR. The GitHub Actions workflow that handles the `open-upstream` label uses `git rebase --onto` to set up the changes destined for upstream. Adding the label and creating this branch long after the PR has been merged or closed can cause merge conflicts with new upstream code and cause the workflow to fail. Adding the label right after creating your PR means that 1) any changes you make to your downstream PR while it is in review won't make it to upstream, and it is up to you to cherry-pick those changes into the upstream branch or remove and re-add the `open-upstream` label to get the Actions workflow to do it for you, and 2) that you're proposing changes to upstream that the rest of the AMD team might still have comments on.
## Daily Upstream Sync
Every day, GitHub Actions will attempt to run the `ROCm Nightly Upstream Sync` workflow. This job
normally does this on its own, but requires a developer to intervene if there's a merge conflict
or if the PR fails CI. Devs should fix or resolve issues with the merge by adding commits to the
PR's branch.
# Branches
* `rocm-main` - the default "trunk" branch for this repo. Should only be changed submitting PRs to it from feature branches created by devs.
* `main` - a copy of `jax-ml/jax:main`. This branch is "read-only" and should only be changed by GitHub Actions.
# CI Workflows
We use GitHub Actions to run tests on PRs and to automate some of our
development tasks. These all live in `.github/workflows`.
| Name | File | Trigger | Description |
|----------------------------|----------------------------------|------------------------------------------------------|----------------------------------------------------------------------------------------|
| ROCm GPU CI | `rocm-ci.yml` | Open or commit changes to a PR targeting `rocm-main` | Builds and runs JAX on ROCm for PRs going into `rocm-main` |
| ROCm Open Upstream PR | `rocm-open-upstream-pr.yml` | Add the `open-upstream` label to a PR | Copies changes from a PR aimed at `rocm-main` into a new PR aimed at upstream's `main` |
| ROCm Nightly Upstream Sync | `rocm-nightly-upstream-sync.yml` | Runs nightly, can be triggered manually via Actions | Opens a PR that merges changes from upstream `main` into our `rocm-main` branch |

View File

@ -450,6 +450,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8) self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
@jtu.sample_product( @jtu.sample_product(
[dict(l_max=l_max, num_z=num_z) [dict(l_max=l_max, num_z=num_z)
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8]) for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])

View File

@ -262,6 +262,7 @@ class PallasCallRemoteDMAInterpretTest(parameterized.TestCase):
@parameterized.parameters(('left',), ('right',)) @parameterized.parameters(('left',), ('right',))
def test_interpret_remote_dma_ppermute(self, permutation): def test_interpret_remote_dma_ppermute(self, permutation):
self.skipTest("ROCm: Skipping for now")
if jax.device_count() <= 1: if jax.device_count() <= 1:
self.skipTest('Test requires multiple devices.') self.skipTest('Test requires multiple devices.')
num_devices = jax.device_count() num_devices = jax.device_count()