mirror of
https://github.com/ROCm/jax.git
synced 2025-04-13 02:16:06 +00:00
Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-144_1
This commit is contained in:
commit
a0edd3fbb2
2
.bazelrc
2
.bazelrc
@ -130,6 +130,8 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
|
||||
build:clang --copt=-Qunused-arguments
|
||||
# Error on struct/class mismatches, since this causes link failures on Windows.
|
||||
build:clang --copt=-Werror=mismatched-tags
|
||||
# Don't error out on C++23 extensions. Needed for building the clang-19.
|
||||
build:clang --copt=-Wno-error=c23-extensions
|
||||
|
||||
# Configs for CUDA
|
||||
build:cuda --repo_env TF_NEED_CUDA=1
|
||||
|
3
.github/CODEOWNERS
vendored
Normal file
3
.github/CODEOWNERS
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
# Require approvals from someone on the JAX team before PRs are merged
|
||||
* @ROCm/jax-devs
|
||||
|
40
.github/workflows/ci-build.yaml
vendored
40
.github/workflows/ci-build.yaml
vendored
@ -1,4 +1,4 @@
|
||||
name: CI
|
||||
name: ROCm CPU CI
|
||||
|
||||
# We test all supported Python versions as follows:
|
||||
# - 3.10 : Documentation build
|
||||
@ -11,10 +11,10 @@ on:
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- rocm-main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- rocm-main
|
||||
|
||||
permissions:
|
||||
contents: read # to fetch code
|
||||
@ -42,12 +42,8 @@ jobs:
|
||||
- run: pre-commit run --show-diff-on-failure --color=always --all-files
|
||||
|
||||
build:
|
||||
# Don't execute in fork due to runner type
|
||||
if: github.repository == 'jax-ml/jax'
|
||||
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
|
||||
runs-on: linux-x86-n2-32
|
||||
container:
|
||||
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
|
||||
runs-on: ROCM-Ubuntu
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
matrix:
|
||||
@ -65,10 +61,6 @@ jobs:
|
||||
num_generated_cases: 1
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Image Setup
|
||||
run: |
|
||||
apt update
|
||||
apt install -y libssl-dev
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
@ -95,12 +87,12 @@ jobs:
|
||||
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
|
||||
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
|
||||
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
|
||||
pytest -n auto --tb=short --maxfail=20 tests examples
|
||||
pytest -n 4 --tb=short --maxfail=20 tests examples
|
||||
|
||||
|
||||
documentation:
|
||||
name: Documentation - test code snippets
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ROCM-Ubuntu
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
@ -128,19 +120,13 @@ jobs:
|
||||
|
||||
documentation_render:
|
||||
name: Documentation - render documentation
|
||||
runs-on: linux-x86-n2-16
|
||||
container:
|
||||
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
|
||||
timeout-minutes: 10
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.10']
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Image Setup
|
||||
run: |
|
||||
apt update
|
||||
apt install -y libssl-dev libsqlite3-dev
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
@ -194,9 +180,7 @@ jobs:
|
||||
|
||||
ffi:
|
||||
name: FFI example
|
||||
runs-on: linux-x86-g2-16-l4-1gpu
|
||||
container:
|
||||
image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
|
||||
runs-on: ROCM-Ubuntu
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@ -207,7 +191,8 @@ jobs:
|
||||
- name: Install JAX
|
||||
run: |
|
||||
pip install uv~=0.5.30
|
||||
uv pip install --system .[cuda12]
|
||||
pip install uv
|
||||
uv pip install --system .
|
||||
- name: Build and install example project
|
||||
run: uv pip install --system ./examples/ffi[test]
|
||||
env:
|
||||
@ -216,10 +201,11 @@ jobs:
|
||||
# a different toolchain. GCC is the default compiler on the
|
||||
# 'ubuntu-latest' runner, but we still set this explicitly just to be
|
||||
# clear.
|
||||
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
|
||||
CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ #-DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
|
||||
- name: Run CPU tests
|
||||
run: python -m pytest examples/ffi/tests
|
||||
env:
|
||||
JAX_PLATFORM_NAME: cpu
|
||||
- name: Run GPU tests
|
||||
run: python -m pytest examples/ffi/tests
|
||||
|
||||
|
31
.github/workflows/rocm-ci.yml
vendored
31
.github/workflows/rocm-ci.yml
vendored
@ -1,23 +1,27 @@
|
||||
name: ROCm GPU Post-Merge Check
|
||||
name: ROCm GPU CI
|
||||
|
||||
on:
|
||||
# Trigger the workflow after a push into the main branch
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the rocm-main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
- rocm-main
|
||||
- 'rocm-jaxlib-v*'
|
||||
pull_request:
|
||||
branches:
|
||||
- rocm-main
|
||||
- 'rocm-jaxlib-v*'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-jax-in-docker:
|
||||
runs-on: linux-x86_64-cirrascale-64-8gpu-amd-mi250
|
||||
build-jax-in-docker: # strategy and matrix come here
|
||||
runs-on: mi-250
|
||||
env:
|
||||
BASE_IMAGE: "ubuntu:22.04"
|
||||
TEST_IMAGE: ubuntu-jax-upstream-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
|
||||
TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
|
||||
PYTHON_VERSION: "3.10"
|
||||
ROCM_VERSION: "6.2.4"
|
||||
WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
|
||||
@ -32,6 +36,9 @@ jobs:
|
||||
ls
|
||||
- name: Print system info
|
||||
run: |
|
||||
whoami
|
||||
printenv
|
||||
df -h
|
||||
rocm-smi
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
@ -50,9 +57,11 @@ jobs:
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
|
||||
path: ${{ env.WORKSPACE_DIR }}/dist/*.whl
|
||||
retention-days: 2
|
||||
path: ./dist/*.whl
|
||||
- name: Run tests
|
||||
env:
|
||||
GPU_COUNT: "8"
|
||||
GFX: "gfx90a"
|
||||
run: |
|
||||
cd $WORKSPACE_DIR
|
||||
python3 build/rocm/ci_build test $TEST_IMAGE --test-cmd "pytest tests/core_test.py"
|
||||
|
66
.github/workflows/rocm-nightly-upstream-sync.yml
vendored
Normal file
66
.github/workflows/rocm-nightly-upstream-sync.yml
vendored
Normal file
@ -0,0 +1,66 @@
|
||||
# Pulls the latest changes from upstream into main and opens a PR to merge
|
||||
# them into rocm-main branch.
|
||||
|
||||
name: ROCm Nightly Upstream Sync
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 6 * * 1-5'
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
env:
|
||||
SYNC_BRANCH_NAME: ci-upstream-sync-${{ github.run_number }}_${{ github.run_attempt }}
|
||||
jobs:
|
||||
sync-main:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Generate an app token
|
||||
id: generate-token
|
||||
uses: actions/create-github-app-token@v1
|
||||
with:
|
||||
app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
|
||||
private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
|
||||
- name: Sync our main with upstream main
|
||||
run: |
|
||||
gh auth status
|
||||
gh repo sync rocm/jax -b main
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.generate-token.outputs.token }}
|
||||
create-sync-branch:
|
||||
needs: sync-main
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Create branch
|
||||
run: |
|
||||
git fetch
|
||||
git checkout origin/main
|
||||
git checkout -b $SYNC_BRANCH_NAME
|
||||
# Try and merge rocm-main into this new branch so that we don't run upstream's CI code
|
||||
git config --global user.email "github-actions@github.com"
|
||||
git config --global user.name "GitHub Actions"
|
||||
git merge origin/rocm-main || true
|
||||
# If the merge creates conflicts, we want to abort and push to origin anyways so that a dev can resolve the conflicts
|
||||
git merge --abort || true
|
||||
git push origin HEAD
|
||||
open-sync-pr:
|
||||
needs: create-sync-branch
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Generate an app token
|
||||
id: generate-token
|
||||
uses: actions/create-github-app-token@v1
|
||||
with:
|
||||
app-id: ${{ vars.ROCM_REPO_MANAGEMENT_API_2_ID }}
|
||||
private-key: ${{ secrets.ROCM_REPO_MANAGEMENT_API_2_PRIV_KEY }}
|
||||
- name: Open a PR to rocm-main
|
||||
run: |
|
||||
gh pr create --repo $GITHUB_REPOSITORY --head $SYNC_BRANCH_NAME --base rocm-main --title "CI: $(date +%x) upstream sync" --body "Daily sync with upstream"
|
||||
gh pr merge --repo $GITHUB_REPOSITORY --merge --auto $SYNC_BRANCH_NAME
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.generate-token.outputs.token }}
|
||||
|
41
.github/workflows/rocm-open-upstream-pr.yml
vendored
Normal file
41
.github/workflows/rocm-open-upstream-pr.yml
vendored
Normal file
@ -0,0 +1,41 @@
|
||||
name: ROCm Open Upstream PR
|
||||
on:
|
||||
pull_request:
|
||||
types: [ labeled ]
|
||||
branches: [ rocm-main ]
|
||||
jobs:
|
||||
open-upstream:
|
||||
if: ${{ github.event.label.name == 'open-upstream' }}
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
NEW_BRANCH_NAME: "${{ github.head_ref }}-upstream"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Rebase code to main
|
||||
run: |
|
||||
git config --global user.email "github-actions@github.com"
|
||||
git config --global user.name "Github Actions"
|
||||
git fetch
|
||||
git checkout -b $NEW_BRANCH_NAME origin/${{ github.head_ref }}
|
||||
git rebase --onto origin/main origin/rocm-main
|
||||
# Force push here so that we don't run into conflicts with the origin branch
|
||||
git push origin HEAD --force
|
||||
- name: Leave link to create PR
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# Bash is not friendly with newline characters, so make our own
|
||||
NL=$'\n'
|
||||
# Encode the PR title and body for passing as URL get parameters
|
||||
TITLE_ENC=$(jq -rn --arg x "[ROCm] ${{ github.event.pull_request.title }}" '$x|@uri')
|
||||
BODY_ENC=$(jq -rn --arg x $"${{ github.event.pull_request.body }}${NL}${NL}Created from: rocm/jax#${{ github.event.pull_request.number }}" '$x|@uri')
|
||||
# Create a link that will open a new PR form to upstream and autofill the fields
|
||||
CREATE_PR_LINK="https://github.com/jax-ml/jax/compare/main...ROCm:jax:$NEW_BRANCH_NAME?expand=1&title=$TITLE_ENC&body=$BODY_ENC"
|
||||
# Add a comment with the link to the PR
|
||||
COMMENT_BODY="Feature branch from main is ready. [Create a new PR][1] destined for upstream?${NL}${NL}[1]: $CREATE_PR_LINK"
|
||||
gh pr comment ${{ github.event.pull_request.number }} --repo rocm/jax --body "$COMMENT_BODY"
|
||||
|
2
.github/workflows/upstream-nightly.yml
vendored
2
.github/workflows/upstream-nightly.yml
vendored
@ -22,7 +22,7 @@ on:
|
||||
|
||||
jobs:
|
||||
upstream-dev:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ROCM-Ubuntu
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write # for failed-build-issue
|
||||
|
@ -262,7 +262,7 @@ def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
|
||||
rocm_group.add_argument(
|
||||
"--rocm_amdgpu_targets",
|
||||
type=str,
|
||||
default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201",
|
||||
default="gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201",
|
||||
help="A comma-separated list of ROCm amdgpu targets to support.",
|
||||
)
|
||||
|
||||
|
@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
|
||||
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
|
||||
|
||||
# Install ROCm
|
||||
@ -62,7 +62,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
|
||||
pytest-reportlog \
|
||||
pytest-rerunfailures \
|
||||
pytest-json-report \
|
||||
pytest-csv \
|
||||
cloudpickle \
|
||||
portpicker \
|
||||
matplotlib \
|
||||
|
@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
|
||||
### Simplified Build Script
|
||||
|
||||
For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.
|
||||
|
||||
|
@ -9,13 +9,13 @@ ARG ROCM_BUILD_NUM
|
||||
# manylinux base image. However, adding this does fix an issue where Bazel isn't able
|
||||
# to find them.
|
||||
RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64
|
||||
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64 numactl-devel
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
|
||||
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM
|
||||
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
|
||||
RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS}
|
||||
|
||||
# Install LLVM 18 and dependencies.
|
||||
|
@ -127,6 +127,31 @@ def dist_wheels(
|
||||
]
|
||||
)
|
||||
|
||||
# Add command for unit tests
|
||||
cmd.extend(
|
||||
[
|
||||
"&&",
|
||||
"bazel",
|
||||
"test",
|
||||
"-k",
|
||||
"--jobs=4",
|
||||
"--test_verbose_timeout_warnings=true",
|
||||
"--test_output=all",
|
||||
"--test_summary=detailed",
|
||||
"--local_test_jobs=1",
|
||||
"--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
|
||||
"--test_env=JAX_SKIP_SLOW_TESTS=0",
|
||||
"--verbose_failures=true",
|
||||
"--config=rocm",
|
||||
"--action_env=ROCM_PATH=/opt/rocm",
|
||||
"--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
|
||||
"--test_tag_filters=-multiaccelerator",
|
||||
"--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
|
||||
"--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
|
||||
"//tests:gpu_tests",
|
||||
]
|
||||
)
|
||||
|
||||
LOG.info("Running: %s", cmd)
|
||||
_ = subprocess.run(cmd, check=True)
|
||||
|
||||
@ -356,3 +381,4 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
|
||||
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
|
||||
|
||||
# Install ROCM
|
||||
@ -47,7 +47,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
|
||||
pytest-reportlog \
|
||||
pytest-rerunfailures \
|
||||
pytest-json-report \
|
||||
pytest-csv \
|
||||
cloudpickle \
|
||||
portpicker \
|
||||
matplotlib \
|
||||
|
@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
|
||||
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
|
||||
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
|
||||
|
||||
# Install ROCM
|
||||
@ -46,7 +46,6 @@ RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
|
||||
pytest-reportlog \
|
||||
pytest-rerunfailures \
|
||||
pytest-json-report \
|
||||
pytest-csv \
|
||||
cloudpickle \
|
||||
portpicker \
|
||||
matplotlib \
|
||||
|
@ -54,11 +54,15 @@ run_tests() {
|
||||
|
||||
python3 -m pytest \
|
||||
--html="${LOG_DIR}/multi_gpu_pmap_test_log.html" \
|
||||
--json-report \
|
||||
--json-report-file="${LOG_DIR}/multi_gpu_pmap_test_log.json" \
|
||||
--reruns 3 \
|
||||
tests/pmap_test.py
|
||||
|
||||
python3 -m pytest \
|
||||
--html="${LOG_DIR}/multi_gpu_multi_device_test_log.html" \
|
||||
--json-report \
|
||||
--json-report-file="${LOG_DIR}/multi_gpu_multi_device_test_log.json" \
|
||||
--reruns 3 \
|
||||
tests/multi_device_test.py
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import csv
|
||||
import json
|
||||
import argparse
|
||||
import threading
|
||||
@ -43,22 +42,6 @@ def combine_json_reports():
|
||||
with open(combined_json_file, "w") as outfile:
|
||||
json.dump(combined_data, outfile, indent=4)
|
||||
|
||||
|
||||
def combine_csv_reports():
|
||||
all_csv_files = [f for f in os.listdir(base_dir) if f.endswith("_log.csv")]
|
||||
combined_csv_file = f"{base_dir}/final_compiled_report.csv"
|
||||
with open(combined_csv_file, mode="w", newline="") as outfile:
|
||||
csv_writer = csv.writer(outfile)
|
||||
for i, csv_file in enumerate(all_csv_files):
|
||||
with open(os.path.join(base_dir, csv_file), mode="r") as infile:
|
||||
csv_reader = csv.reader(infile)
|
||||
if i == 0:
|
||||
# write headers only once
|
||||
csv_writer.writerow(next(csv_reader))
|
||||
for row in csv_reader:
|
||||
csv_writer.writerow(row)
|
||||
|
||||
|
||||
def generate_final_report(shell=False, env_vars={}):
|
||||
env = os.environ
|
||||
env = {**env, **env_vars}
|
||||
@ -76,8 +59,6 @@ def generate_final_report(shell=False, env_vars={}):
|
||||
|
||||
# Generate json reports.
|
||||
combine_json_reports()
|
||||
# Generate csv reports.
|
||||
combine_csv_reports()
|
||||
|
||||
|
||||
def run_shell_command(cmd, shell=False, env_vars={}):
|
||||
@ -147,9 +128,6 @@ def run_test(testmodule, gpu_tokens, continue_on_fail):
|
||||
"pytest",
|
||||
"--json-report",
|
||||
f"--json-report-file={base_dir}/{testfile}_log.json",
|
||||
f"--csv={base_dir}/{testfile}_log.csv",
|
||||
"--csv-columns",
|
||||
"id,module,name,file,status,duration",
|
||||
f"--html={base_dir}/{testfile}_log.html",
|
||||
"--reruns",
|
||||
"3",
|
||||
@ -163,9 +141,6 @@ def run_test(testmodule, gpu_tokens, continue_on_fail):
|
||||
"pytest",
|
||||
"--json-report",
|
||||
f"--json-report-file={base_dir}/{testfile}_log.json",
|
||||
f"--csv={base_dir}/{testfile}_log.csv",
|
||||
"--csv-columns",
|
||||
"id,module,name,file,status,duration",
|
||||
f"--html={base_dir}/{testfile}_log.html",
|
||||
"--reruns",
|
||||
"3",
|
||||
|
@ -16,77 +16,76 @@ ROCM_BUILD_NUM=main
|
||||
# Initial releases don't have the trailing '.0'
|
||||
# For example ROCM 5.7.0 is at https://repo.radeon.com/rocm/apt/5.7/
|
||||
if [ ${ROCM_VERSION##*[^0-9]} -eq '0' ]; then
|
||||
ROCM_VERS=${ROCM_VERSION%.*}
|
||||
ROCM_VERS=${ROCM_VERSION%.*}
|
||||
else
|
||||
ROCM_VERS=$ROCM_VERSION
|
||||
ROCM_VERS=$ROCM_VERSION
|
||||
fi
|
||||
ROCM_DEB_REPO=${ROCM_DEB_REPO_HOME}${ROCM_VERS}/
|
||||
|
||||
if [ ! -f "/${CUSTOM_INSTALL}" ]; then
|
||||
# Add rocm repository
|
||||
chmod 1777 /tmp
|
||||
DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update
|
||||
DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common
|
||||
DEBIAN_FRONTEND=noninteractive apt-get clean all
|
||||
wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -;
|
||||
if [[ $ROCM_DEB_REPO == https://repo.radeon.com/rocm/* ]] ; then \
|
||||
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list; \
|
||||
else \
|
||||
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list ; \
|
||||
fi
|
||||
#Install rocm and other packages
|
||||
apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
||||
build-essential \
|
||||
software-properties-common \
|
||||
clang-6.0 \
|
||||
clang-format-6.0 \
|
||||
curl \
|
||||
g++-multilib \
|
||||
git \
|
||||
vim \
|
||||
libnuma-dev \
|
||||
virtualenv \
|
||||
python3-pip \
|
||||
pciutils \
|
||||
python-is-python3 \
|
||||
libffi-dev \
|
||||
libssl-dev \
|
||||
build-essential \
|
||||
zlib1g-dev \
|
||||
libbz2-dev \
|
||||
libreadline-dev \
|
||||
libsqlite3-dev curl \
|
||||
libncursesw5-dev \
|
||||
xz-utils \
|
||||
tk-dev \
|
||||
libxml2-dev \
|
||||
libxmlsec1-dev \
|
||||
libffi-dev \
|
||||
liblzma-dev \
|
||||
wget \
|
||||
rocm-dev \
|
||||
rocm-libs \
|
||||
miopen-hip \
|
||||
miopen-hip-dev \
|
||||
rocblas \
|
||||
rocblas-dev \
|
||||
rocsolver-dev \
|
||||
rocrand-dev \
|
||||
rocfft-dev \
|
||||
hipfft-dev \
|
||||
hipblas-dev \
|
||||
rocprim-dev \
|
||||
hipcub-dev \
|
||||
rccl-dev \
|
||||
hipsparse-dev \
|
||||
hipsolver-dev \
|
||||
wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add rocm repository
|
||||
chmod 1777 /tmp
|
||||
DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update
|
||||
DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common
|
||||
DEBIAN_FRONTEND=noninteractive apt-get clean all
|
||||
wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
|
||||
if [[ $ROCM_DEB_REPO == https://repo.radeon.com/rocm/* ]]; then
|
||||
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" >/etc/apt/sources.list.d/rocm.list
|
||||
else
|
||||
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" >/etc/apt/sources.list.d/rocm.list
|
||||
fi
|
||||
#Install rocm and other packages
|
||||
apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
||||
build-essential \
|
||||
software-properties-common \
|
||||
clang-6.0 \
|
||||
clang-format-6.0 \
|
||||
curl \
|
||||
g++-multilib \
|
||||
git \
|
||||
vim \
|
||||
libnuma-dev \
|
||||
virtualenv \
|
||||
python3-pip \
|
||||
pciutils \
|
||||
python-is-python3 \
|
||||
libffi-dev \
|
||||
libssl-dev \
|
||||
build-essential \
|
||||
zlib1g-dev \
|
||||
libbz2-dev \
|
||||
libreadline-dev \
|
||||
libsqlite3-dev curl \
|
||||
libncursesw5-dev \
|
||||
xz-utils \
|
||||
tk-dev \
|
||||
libxml2-dev \
|
||||
libxmlsec1-dev \
|
||||
libffi-dev \
|
||||
liblzma-dev \
|
||||
wget \
|
||||
rocm-dev \
|
||||
rocm-libs \
|
||||
miopen-hip \
|
||||
miopen-hip-dev \
|
||||
rocblas \
|
||||
rocblas-dev \
|
||||
rocsolver-dev \
|
||||
rocrand-dev \
|
||||
rocfft-dev \
|
||||
hipfft-dev \
|
||||
hipblas-dev \
|
||||
rocprim-dev \
|
||||
hipcub-dev \
|
||||
rccl-dev \
|
||||
hipsparse-dev \
|
||||
hipsolver-dev \
|
||||
wget &&
|
||||
apt-get clean &&
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
else
|
||||
bash "/${CUSTOM_INSTALL}"
|
||||
bash "/${CUSTOM_INSTALL}"
|
||||
fi
|
||||
|
||||
echo $ROCM_VERSION
|
||||
@ -95,6 +94,6 @@ echo $ROCM_PATH
|
||||
echo $GPU_DEVICE_TARGETS
|
||||
|
||||
# Ensure the ROCm target list is set up
|
||||
GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS:-"gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"}
|
||||
GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS:-"gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"}
|
||||
printf '%s\n' ${GPU_DEVICE_TARGETS} | tee -a "$ROCM_PATH/bin/target.lst"
|
||||
touch "${ROCM_PATH}/.info/version"
|
||||
|
@ -35,7 +35,7 @@ import sys
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1200 gfx1201"
|
||||
GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
|
||||
|
||||
|
||||
def build_rocm_path(rocm_version_str):
|
||||
|
53
build/rocm/upload_wheels.sh
Normal file
53
build/rocm/upload_wheels.sh
Normal file
@ -0,0 +1,53 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Check for user-supplied arguments.
|
||||
if [[ $# -lt 2 ]]; then
|
||||
echo "Usage: $0 <jax_home_directory> <version>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set JAX_HOME and RELEASE_VERSION from user arguments.
|
||||
JAX_HOME=$1
|
||||
RELEASE_VERSION=$2
|
||||
WHEELHOUSE="$JAX_HOME/wheelhouse"
|
||||
|
||||
# Projects to upload separately to PyPI.
|
||||
PROJECTS=("jax_rocm60_pjrt" "jax_rocm60_plugin")
|
||||
|
||||
# PyPI API Token.
|
||||
PYPI_API_TOKEN=${PYPI_API_TOKEN:-"pypi-replace_with_token"}
|
||||
|
||||
# Ensure the specified JAX_HOME and wheelhouse directories exist.
|
||||
if [[ ! -d "$JAX_HOME" ]]; then
|
||||
echo "Error: The specified JAX_HOME directory does not exist: $JAX_HOME"
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -d "$WHEELHOUSE" ]]; then
|
||||
echo "Error: The wheelhouse directory does not exist: $WHEELHOUSE"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
upload_and_release_project() {
|
||||
local project=$1
|
||||
|
||||
echo "Searching for wheels matching project: $project version: $RELEASE_VERSION..."
|
||||
wheels=($(ls $WHEELHOUSE | grep "^${project}-${RELEASE_VERSION}[.-].*\.whl"))
|
||||
if [[ ${#wheels[@]} -eq 0 ]]; then
|
||||
echo "No wheels found for project: $project version: $RELEASE_VERSION. Skipping..."
|
||||
return
|
||||
fi
|
||||
echo "Found wheels for $project: ${wheels[*]}"
|
||||
|
||||
echo "Uploading wheels for $project version $RELEASE_VERSION to PyPI..."
|
||||
for wheel in "${wheels[@]}"; do
|
||||
twine upload --verbose --repository pypi --non-interactive --username "__token__" --password "$PYPI_API_TOKEN" "$WHEELHOUSE/$wheel"
|
||||
done
|
||||
}
|
||||
|
||||
# Install twine if not already installed.
|
||||
python -m pip install --upgrade twine
|
||||
|
||||
# Iterate over each project and upload its wheels.
|
||||
for project in "${PROJECTS[@]}"; do
|
||||
upload_and_release_project $project
|
||||
done
|
65
rocm-downstream-dev-guide.md
Normal file
65
rocm-downstream-dev-guide.md
Normal file
@ -0,0 +1,65 @@
|
||||
# ROCm CI Dev Guide
|
||||
|
||||
This guide lays out how to do some dev operations, what branches live in this repo, and what CI workflows live in this repo.
|
||||
|
||||
# Quick Tips
|
||||
|
||||
1. Always use "Squash and Merge" when merging PRs into `rocm-main` (unless you're merging the daily sync from upstream).
|
||||
2. When submitting a PR to `rocm-main`, make sure that your feature branch started from `rocm-main`. When you started working on your feature, did you do `git checkout rocm-main && git checkout -b <my feature branch>`?
|
||||
3. Always fill out your PR's description with an explanation of why we need this change to give context to your fellow devs (and your future self).
|
||||
4. In the PR description, link to the story or GitHub issue that this PR solves.
|
||||
|
||||
# Processes
|
||||
|
||||
## Making a Change
|
||||
|
||||
1. Clone `rocm/jax` and check out the `rocm-main` branch.
|
||||
2. Create a new feature branch with `git checkout -b <my feature name>`.
|
||||
3. Make your changes on the feature branch and test them locally.
|
||||
4. Push your changes to a new feature branch in `rocm/jax` by running
|
||||
`git push origin HEAD`.
|
||||
5. Open a PR from your new feature branch into `rocm-main` with a nice description telling your
|
||||
team members what the change is for. Bonus points if you can link it to an issue or story.
|
||||
6. Add reviewers, wait for approval, and make sure CI passes.
|
||||
7. Depending on your specific change, either:
|
||||
a. If this is a normal, run-of-the-mill change that we want to put upstream, add the
|
||||
`open-upstream` label to your PR and close your PR. In a few minutes, Actions will
|
||||
comment on your PR with a link that lets you open a new PR into upstream. The link will
|
||||
autofill some PR info, and the new PR will be created on a new branch that has the same name
|
||||
as your old feature branch, but with the `-upstream` suffix appended to the end of it.
|
||||
If upstream reviewers request some changes to the new PR before merging, you can add
|
||||
or modify commits on the new `-upstream` feature branch.
|
||||
b. If this is an urgent change that we want in `rocm-main` right now but also want upstream,
|
||||
add the `open-upstream` label, merge your PR, and then follow the link that Actions leaves in a comment on your PR to open a new PR into upstream.
|
||||
c. If this is a change that we only want to keep in `rocm/jax` and not push into upstream,
|
||||
squash and merge your PR.
|
||||
|
||||
If you submitted your PR upstream with `open-upstream`, you should see your change in `rocm-main`
|
||||
the next time the `ROCm Nightly Upstream Sync` workflow is run and the PR that it creates is
|
||||
merged.
|
||||
|
||||
When using the `open-upstream` label to move changes to upstream, it's best to put the label on the PR when you either close or merge the PR. The GitHub Actions workflow that handles the `open-upstream` label uses `git rebase --onto` to set up the changes destined for upstream. Adding the label and creating this branch long after the PR has been merged or closed can cause merge conflicts with new upstream code and cause the workflow to fail. Adding the label right after creating your PR means that 1) any changes you make to your downstream PR while it is in review won't make it to upstream, and it is up to you to cherry-pick those changes into the upstream branch or remove and re-add the `open-upstream` label to get the Actions workflow to do it for you, and 2) that you're proposing changes to upstream that the rest of the AMD team might still have comments on.
|
||||
|
||||
## Daily Upstream Sync
|
||||
|
||||
Every day, GitHub Actions will attempt to run the `ROCm Nightly Upstream Sync` workflow. This job
|
||||
normally does this on its own, but requires a developer to intervene if there's a merge conflict
|
||||
or if the PR fails CI. Devs should fix or resolve issues with the merge by adding commits to the
|
||||
PR's branch.
|
||||
|
||||
# Branches
|
||||
|
||||
* `rocm-main` - the default "trunk" branch for this repo. Should only be changed by submitting PRs to it from feature branches created by devs.
|
||||
* `main` - a copy of `jax-ml/jax:main`. This branch is "read-only" and should only be changed by GitHub Actions.
|
||||
|
||||
# CI Workflows
|
||||
|
||||
We use GitHub Actions to run tests on PRs and to automate some of our
|
||||
development tasks. These all live in `.github/workflows`.
|
||||
|
||||
| Name | File | Trigger | Description |
|
||||
|----------------------------|----------------------------------|------------------------------------------------------|----------------------------------------------------------------------------------------|
|
||||
| ROCm GPU CI | `rocm-ci.yml` | Open or commit changes to a PR targeting `rocm-main` | Builds and runs JAX on ROCm for PRs going into `rocm-main` |
|
||||
| ROCm Open Upstream PR | `rocm-open-upstream-pr.yml` | Add the `open-upstream` label to a PR | Copies changes from a PR aimed at `rocm-main` into a new PR aimed at upstream's `main` |
|
||||
| ROCm Nightly Upstream Sync | `rocm-nightly-upstream-sync.yml` | Runs nightly, can be triggered manually via Actions | Opens a PR that merges changes from upstream `main` into our `rocm-main` branch |
|
||||
|
@ -450,6 +450,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)
|
||||
|
||||
@unittest.skip(reason="https://github.com/jax-ml/jax/pull/25675")
|
||||
@jtu.sample_product(
|
||||
[dict(l_max=l_max, num_z=num_z)
|
||||
for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
|
||||
|
@ -262,6 +262,7 @@ class PallasCallRemoteDMAInterpretTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(('left',), ('right',))
|
||||
def test_interpret_remote_dma_ppermute(self, permutation):
|
||||
self.skipTest("ROCm: Skipping for now")
|
||||
if jax.device_count() <= 1:
|
||||
self.skipTest('Test requires multiple devices.')
|
||||
num_devices = jax.device_count()
|
||||
|
Loading…
x
Reference in New Issue
Block a user