mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[ROCm]: Updates for container and build script
-Updated dockerfile.ms -Updated build script to switch building against XLA repo -Update CI script -Update jaxlib setup.py to add rocm version
This commit is contained in:
parent
94674b9ae1
commit
b0e541a730
@ -2,12 +2,26 @@
|
||||
FROM rocm/dev-ubuntu-20.04:5.4-complete as rt_build
|
||||
MAINTAINER Rahul Batra<rahbatra@amd.com>
|
||||
################################################################################
|
||||
ARG ROCM_PATH=/opt/rocm-5.4.0
|
||||
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.5/
|
||||
ARG ROCM_BUILD_NAME=ubuntu
|
||||
ARG ROCM_BUILD_NUM=main
|
||||
ARG ROCM_PATH=/opt/rocm-5.5.0
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG PYTHON_VERSION=3.9.0
|
||||
ENV HOME /root/
|
||||
ENV ROCM_PATH=$ROCM_PATH
|
||||
|
||||
RUN apt-get --allow-unauthenticated update && apt install -y wget software-properties-common
|
||||
RUN apt-get clean all
|
||||
RUN wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -;
|
||||
RUN bin/bash -c 'if [[ $ROCM_DEB_REPO == http://repo.radeon.com/rocm/* ]] ; then \
|
||||
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list; \
|
||||
else \
|
||||
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list ; \
|
||||
fi'
|
||||
|
||||
|
||||
RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
||||
build-essential \
|
||||
software-properties-common \
|
||||
@ -21,10 +35,34 @@ RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteracti
|
||||
virtualenv \
|
||||
python3-pip \
|
||||
pciutils \
|
||||
python-is-python3 \
|
||||
libffi-dev \
|
||||
libssl-dev \
|
||||
build-essential \
|
||||
zlib1g-dev \
|
||||
libbz2-dev \
|
||||
libreadline-dev \
|
||||
libsqlite3-dev curl \
|
||||
libncursesw5-dev \
|
||||
xz-utils \
|
||||
tk-dev \
|
||||
libxml2-dev \
|
||||
libxmlsec1-dev \
|
||||
libffi-dev \
|
||||
liblzma-dev \
|
||||
wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add to get ppa
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y software-properties-common
|
||||
# Install rocm pkgs
|
||||
RUN apt-get update --allow-insecure-repositories && \
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
|
||||
rocm-dev rocm-libs rccl && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Set up paths
|
||||
ENV HCC_HOME=$ROCM_PATH/hcc
|
||||
@ -55,26 +93,14 @@ ENV PATH="/root/bin:/root/.local/bin:$PATH"
|
||||
|
||||
|
||||
# Install python3.9
|
||||
RUN add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt update && \
|
||||
apt install -y python3.9-dev \
|
||||
python3-pip \
|
||||
python3.9-distutils \
|
||||
python-is-python3
|
||||
# Install pyenv with different python versions
|
||||
RUN git clone https://github.com/pyenv/pyenv.git /pyenv
|
||||
|
||||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
|
||||
ENV PYENV_ROOT /pyenv
|
||||
ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH
|
||||
|
||||
RUN pip3 install --upgrade --force-reinstall setuptools pip
|
||||
RUN pyenv install $PYTHON_VERSION
|
||||
|
||||
RUN pip3 install absl-py numpy==1.20.0 scipy wheel six setuptools pytest pytest-rerunfailures matplotlib
|
||||
RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip==22.0 && pip install numpy==1.21.0 setuptools wheel six auditwheel scipy pytest pytest-rerunfailures matplotlib absl-py
|
||||
|
||||
# Get jax and build it with ROCm
|
||||
RUN git clone https://github.com/google/jax.git
|
||||
|
||||
################################################################################
|
||||
FROM rt_build as ci_build
|
||||
################################################################################
|
||||
WORKDIR /jax
|
||||
RUN ./build/rocm/build_rocm.sh
|
||||
RUN ./build/rocm/run_single_gpu.py
|
||||
RUN ./build/rocm/run_multi_gpu.sh
|
||||
|
@ -1,90 +0,0 @@
|
||||
FROM ubuntu:focal
|
||||
MAINTAINER Rahul Batra<rahbatra@amd.com>
|
||||
|
||||
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.4/
|
||||
ARG ROCM_BUILD_NAME=ubuntu
|
||||
ARG ROCM_BUILD_NUM=main
|
||||
ARG ROCM_PATH=/opt/rocm-5.4.0
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ENV HOME /root/
|
||||
ENV ROCM_PATH=$ROCM_PATH
|
||||
|
||||
RUN apt-get --allow-unauthenticated update && apt install -y wget software-properties-common
|
||||
RUN apt-get clean all
|
||||
RUN wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -;
|
||||
RUN bin/bash -c 'if [[ $ROCM_DEB_REPO == http://repo.radeon.com/rocm/* ]] ; then \
|
||||
echo "deb [arch=amd64] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list; \
|
||||
else \
|
||||
echo "deb [arch=amd64 trusted=yes] $ROCM_DEB_REPO $ROCM_BUILD_NAME $ROCM_BUILD_NUM" > /etc/apt/sources.list.d/rocm.list ; \
|
||||
fi'
|
||||
|
||||
|
||||
RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
|
||||
build-essential \
|
||||
software-properties-common \
|
||||
clang-6.0 \
|
||||
clang-format-6.0 \
|
||||
curl \
|
||||
g++-multilib \
|
||||
git \
|
||||
vim \
|
||||
libnuma-dev \
|
||||
virtualenv \
|
||||
python3-pip \
|
||||
pciutils \
|
||||
wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add to get ppa
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y software-properties-common
|
||||
# Install rocm pkgs
|
||||
RUN apt-get update --allow-insecure-repositories && \
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
|
||||
rocm-dev rocm-libs rccl && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Set up paths
|
||||
ENV HCC_HOME=$ROCM_PATH/hcc
|
||||
ENV HIP_PATH=$ROCM_PATH/hip
|
||||
ENV OPENCL_ROOT=$ROCM_PATH/opencl
|
||||
ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}"
|
||||
ENV PATH="$ROCM_PATH/bin:${PATH}"
|
||||
ENV PATH="$OPENCL_ROOT/bin:${PATH}"
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
RUN bash -c 'echo -e "gfx900\ngfx906\ngfx908\ngfx90a\ngfx1030" >> ${ROCM_PATH}/bin/target.lst'
|
||||
|
||||
# Need to explicitly create the $ROCM_PATH/.info/version file to workaround what seems to be a bazel bug
|
||||
# The env vars being set via --action_env in .bazelrc and .tf_configure.bazelrc files are sometimes
|
||||
# not getting set in the build command being spawned by bazel (in theory this should not happen)
|
||||
# As a consequence ROCM_PATH is sometimes not set for the hipcc commands.
|
||||
# When hipcc invokes hcc, it specifies $ROCM_PATH/.../include dirs via the `-isystem` options
|
||||
# If ROCM_PATH is not set, it defaults to /opt/rocm, and as a consequence a dependency is generated on the
|
||||
# header files included within `/opt/rocm`, which then leads to bazel dependency errors
|
||||
# Explicitly creating the $ROCM_PATH/.info/version allows ROCM path to be set correctly, even when ROCM_PATH
|
||||
# is not explicitly set, and thus avoids the eventual bazel dependency error.
|
||||
# The bazel bug needs to be root-caused and addressed, but that is out of our control and may take a long time
|
||||
# to come to fruition, so implementing the workaround to make do till then
|
||||
# Filed https://github.com/bazelbuild/bazel/issues/11163 for tracking this
|
||||
RUN touch ${ROCM_PATH}/.info/version
|
||||
|
||||
ENV PATH="/root/bin:/root/.local/bin:$PATH"
|
||||
|
||||
|
||||
# Install python3.9
|
||||
RUN add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt update && \
|
||||
apt install -y python3.9-dev \
|
||||
python3-pip \
|
||||
python3.9-distutils \
|
||||
python-is-python3
|
||||
|
||||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
|
||||
|
||||
RUN pip3 install --upgrade --force-reinstall setuptools pip
|
||||
|
||||
RUN pip3 install absl-py numpy==1.20.0 scipy wheel six setuptools pytest pytest-rerunfailures matplotlib
|
@ -13,23 +13,55 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -eux
|
||||
# Environment Var Notes
|
||||
# XLA_CLONE_DIR -
|
||||
# Specifies filepath to where XLA repo is cloned.
|
||||
# NOTE: if this is set, the XLA repo is not cloned. You must clone the repo before running this script.
|
||||
# Also, if this is set, then setting XLA_REPO and XLA_BRANCH has no effect.
|
||||
# XLA_REPO
|
||||
# XLA repo to clone from. Default is https://github.com/ROCmSoftwarePlatform/tensorflow-upstream
|
||||
# XLA_BRANCH
|
||||
# XLA branch in the XLA repo. Default is develop-upstream-jax
|
||||
#
|
||||
|
||||
ROCM_TF_FORK_REPO="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream"
|
||||
ROCM_TF_FORK_BRANCH="develop-upstream"
|
||||
set -eux
|
||||
python -V
|
||||
|
||||
#If XLA_REPO is not set, then use default
|
||||
if [ ! -v XLA_REPO ]; then
|
||||
XLA_REPO="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream"
|
||||
XLA_BRANCH="develop-upstream-jax"
|
||||
elif [ -z "$XLA_REPO" ]; then
|
||||
XLA_REPO="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream"
|
||||
XLA_BRANCH="develop-upstream-jax"
|
||||
fi
|
||||
|
||||
#If XLA_CLONE_DIR is not set, then use default path.
|
||||
#Note, setting XLA_CLONE_DIR makes setting XLA_REPO and XLA_BRANCH a no-op
|
||||
#Set this when the XLA repository has already been cloned. This is useful in CI
|
||||
#environments and when doing local development
|
||||
if [ ! -v XLA_CLONE_DIR ]; then
|
||||
XLA_CLONE_DIR=/tmp/tensorflow-upstream
|
||||
rm -rf /tmp/tensorflow-upstream || true
|
||||
git clone -b ${ROCM_TF_FORK_BRANCH} ${ROCM_TF_FORK_REPO} /tmp/tensorflow-upstream
|
||||
if [ ! -v TENSORFLOW_ROCM_COMMIT ]; then
|
||||
echo "The TENSORFLOW_ROCM_COMMIT environment variable is not set, using top of branch"
|
||||
elif [ ! -z "$TENSORFLOW_ROCM_COMMIT" ]
|
||||
then
|
||||
echo "Using tensorflow-rocm at commit: $TENSORFLOW_ROCM_COMMIT"
|
||||
cd /tmp/tensorflow-upstream
|
||||
git checkout $TENSORFLOW_ROCM_COMMIT
|
||||
cd -
|
||||
git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/tensorflow-upstream
|
||||
elif [ -z "$XLA_CLONE_DIR" ]; then
|
||||
XLA_CLONE_DIR=/tmp/tensorflow-upstream
|
||||
rm -rf /tmp/tensorflow-upstream || true
|
||||
git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/tensorflow-upstream
|
||||
fi
|
||||
|
||||
|
||||
python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=org_tensorflow=/tmp/tensorflow-upstream
|
||||
#Export JAX_ROCM_VERSION so that it is appended to the wheel name
|
||||
rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1)
|
||||
export JAX_ROCM_VERSION=${rocm_version//./}
|
||||
|
||||
#Build and install wheel
|
||||
python3 ./build/build.py --enable_rocm --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR}
|
||||
pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA)
|
||||
pip3 install --force-reinstall . # installs jax
|
||||
|
||||
#This is for CI to read without having to start the container again
|
||||
if [ -v CI_RUN ]; then
|
||||
pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1 > jax_version_installed
|
||||
cat /opt/rocm/.info/version | cut -d "-" -f 1 > jax_rocm_version
|
||||
fi
|
||||
|
@ -13,16 +13,26 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Usage: ci_build.sh [--dockerfile <DOCKERFILE_PATH> --keep_image]
|
||||
# Usage: ci_build.sh [--dockerfile <DOCKERFILE_PATH> --keep_image --py_version <PYTHON_VERSION>]
|
||||
# <COMMAND>
|
||||
#
|
||||
# DOCKERFILE_PATH: (Optional) Path to the Dockerfile used for docker build.
|
||||
# DOCKERFILE_PATH: (Optional) Path to the Dockerfile used for docker build.
|
||||
# If this optional value is not supplied (via the --dockerfile flag)
|
||||
# Dockerfile.rocm (located in the same directory as this script)
|
||||
# Dockerfile.ms (located in the same directory as this script)
|
||||
# will be used.
|
||||
# KEEP_IMAGE: (Optional) If this flag is set, the container will be committed as an image
|
||||
#
|
||||
# PYTHON_VERSION: Python version to use
|
||||
#
|
||||
# COMMAND: Command to be executed in the docker container
|
||||
#
|
||||
# Environment variables read by this script
|
||||
# WORKSPACE
|
||||
# XLA_REPO
|
||||
# XLA_BRANCH
|
||||
# XLA_CLONE_DIR
|
||||
# BUILD_TAG
|
||||
#
|
||||
|
||||
set -eux
|
||||
|
||||
@ -33,12 +43,17 @@ CONTAINER_TYPE="rocm"
|
||||
DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms"
|
||||
DOCKER_CONTEXT_PATH="${SCRIPT_DIR}"
|
||||
KEEP_IMAGE="--rm"
|
||||
KEEP_CONTAINER="--rm"
|
||||
POSITIONAL_ARGS=()
|
||||
|
||||
RUNTIME_FLAG=0
|
||||
RUNTIME_FLAG=1
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--py_version)
|
||||
PYTHON_VERSION="$2"
|
||||
shift 2
|
||||
;;
|
||||
--dockerfile)
|
||||
DOCKERFILE_PATH="$2"
|
||||
DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}")
|
||||
@ -52,6 +67,11 @@ while [[ $# -gt 0 ]]; do
|
||||
RUNTIME_FLAG=1
|
||||
shift 1
|
||||
;;
|
||||
--keep_container)
|
||||
KEEP_CONTAINER=""
|
||||
shift 1
|
||||
;;
|
||||
|
||||
*)
|
||||
POSITIONAL_ARGS+=("$1")
|
||||
shift
|
||||
@ -78,7 +98,7 @@ WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
|
||||
BUILD_TAG="${BUILD_TAG:-jax}"
|
||||
|
||||
# Determine the docker image name and BUILD_TAG.
|
||||
DOCKER_IMG_NAME="${BUILD_TAG}_${CONTAINER_TYPE}"
|
||||
DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}"
|
||||
|
||||
# Under Jenkins matrix build, the build tag may contain characters such as
|
||||
# commas (,) and equal signs (=), which are not valid inside docker image names.
|
||||
@ -94,13 +114,17 @@ echo "BUILD_TAG: ${BUILD_TAG}"
|
||||
echo " (docker container name will be ${DOCKER_IMG_NAME})"
|
||||
echo ""
|
||||
|
||||
echo "Building container (${DOCKER_IMG_NAME})..."
|
||||
echo "Python Version (${PYTHON_VERSION})"
|
||||
if [[ "${RUNTIME_FLAG}" -eq 1 ]]; then
|
||||
echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..."
|
||||
docker build --target rt_build --tag ${DOCKER_IMG_NAME} \
|
||||
--build-arg PYTHON_VERSION=$PYTHON_VERSION \
|
||||
-f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
|
||||
else
|
||||
echo "Building (CI) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..."
|
||||
docker build --target ci_build --tag ${DOCKER_IMG_NAME} \
|
||||
--build-arg PYTHON_VERSION=$PYTHON_VERSION \
|
||||
-f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
|
||||
fi
|
||||
|
||||
@ -112,22 +136,32 @@ fi
|
||||
# Run the command inside the container.
|
||||
echo "Running '${POSITIONAL_ARGS[*]}' inside ${DOCKER_IMG_NAME}..."
|
||||
|
||||
export TENSORFLOW_ROCM_COMMIT="${TENSORFLOW_ROCM_COMMIT:-}"
|
||||
export XLA_REPO="${XLA_REPO:-}"
|
||||
export XLA_BRANCH="${XLA_BRANCH:-}"
|
||||
export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
|
||||
export JAX_RENAME_WHL="${XLA_CLONE_DIR:-}"
|
||||
|
||||
if [ ! -z ${XLA_CLONE_DIR} ]; then
|
||||
ROCM_EXTRA_PARAMS=${ROCM_EXTRA_PARAMS}" -v ${XLA_CLONE_DIR}:${XLA_CLONE_DIR}"
|
||||
fi
|
||||
|
||||
docker run ${KEEP_IMAGE} --name ${DOCKER_IMG_NAME} --pid=host \
|
||||
-v ${WORKSPACE}:/workspace \
|
||||
-w /workspace \
|
||||
-e TENSORFLOW_ROCM_COMMIT=${TENSORFLOW_ROCM_COMMIT} \
|
||||
-e XLA_REPO=${XLA_REPO} \
|
||||
-e XLA_BRANCH=${XLA_BRANCH} \
|
||||
-e XLA_CLONE_DIR=${XLA_CLONE_DIR} \
|
||||
-e PYTHON_VERSION=$PYTHON_VERSION \
|
||||
-e CI_RUN=1 \
|
||||
${ROCM_EXTRA_PARAMS} \
|
||||
"${DOCKER_IMG_NAME}" \
|
||||
${POSITIONAL_ARGS[@]}
|
||||
|
||||
if [[ "${KEEP_IMAGE}" != "--rm" ]] && [[ $? == "0" ]]; then
|
||||
echo "Committing the docker container as jax-rocm"
|
||||
echo "Committing the docker container as ${DOCKER_IMG_NAME}"
|
||||
docker stop ${DOCKER_IMG_NAME}
|
||||
docker commit ${DOCKER_IMG_NAME} jax-rocm
|
||||
docker commit ${DOCKER_IMG_NAME} ${DOCKER_IMG_NAME}
|
||||
docker rm ${DOCKER_IMG_NAME} # remove this temp container
|
||||
docker rmi ${DOCKER_IMG_NAME} # remove this temp image
|
||||
fi
|
||||
|
||||
echo "Jax-ROCm build was successful!"
|
||||
|
@ -30,6 +30,10 @@ cudnn_version = os.environ.get("JAX_CUDNN_VERSION")
|
||||
if cuda_version and cudnn_version:
|
||||
__version__ += f"+cuda{cuda_version.replace('.', '')}-cudnn{cudnn_version.replace('.', '')}"
|
||||
|
||||
rocm_version = os.environ.get("JAX_ROCM_VERSION")
|
||||
if rocm_version:
|
||||
__version__ += f"+rocm{rocm_version.replace('.', '')}"
|
||||
|
||||
class BinaryDistribution(Distribution):
|
||||
"""This class makes 'bdist_wheel' include an ABI tag on the wheel."""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user