# rocm_jax/build/rocm/docker/Dockerfile.jax-ubu24
#
# File-viewer metadata (not part of the build): 71 lines, 2.1 KiB, Docker.

FROM ubuntu:24.04
# Install Python plus the sqlite3 and bzip2 development packages in a single
# layer. The apt package lists are deleted in the same RUN that downloads them;
# deleting them in a later RUN would not shrink the earlier layer.
# /var/cache/apt is a BuildKit cache mount, so its contents speed up rebuilds
# but are never committed to the image.
# DEBIAN_FRONTEND is set only for this command so no package can block on an
# interactive prompt, without leaking the setting into the runtime env.
RUN --mount=type=cache,target=/var/cache/apt \
    apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
        libbz2-dev \
        libsqlite3-dev \
        python-is-python3 \
        python3 \
        python3-pip \
        sqlite3 \
    && rm -rf /var/lib/apt/lists/*
# GPU architectures (gfx targets) to build for. Declared as a build arg so it
# can be overridden with --build-arg, and re-exported as ENV so the value is
# also visible to later build steps and inside the running container.
# NOTE(review): declared before the ROCm install, so changing it invalidates
# that layer's cache; move it below only if get_rocm.py does not read it.
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCm via the repository helper script. The script is bind-mounted for
# this step only, so it is not copied into the image.
# ROCM_BUILD_JOB / ROCM_BUILD_NUM default to empty; their meaning is defined by
# get_rocm.py (presumably a CI job name and build number — confirm there).
# Arguments are quoted so build-arg values are never word-split or globbed.
ARG ROCM_VERSION=6.2.0
ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION}
ENV ROCM_PATH=${ROCM_PATH}
ARG ROCM_BUILD_JOB
ARG ROCM_BUILD_NUM
RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
    --mount=type=cache,target=/var/cache/apt \
    python3 get_rocm.py \
        --rocm-version="${ROCM_VERSION}" \
        --job-name="${ROCM_BUILD_JOB}" \
        --build-num="${ROCM_BUILD_NUM}"
# Toolchain locations derived from ROCM_PATH, exported into the image env.
# NOTE(review): hcc/ and opencl/ look like legacy ROCm layouts and may not
# exist in the installed ROCm version; kept unchanged in case anything reads
# these variables — confirm before removing.
ENV HCC_HOME=$ROCM_PATH/hcc \
    HIP_PATH=$ROCM_PATH/ \
    OPENCL_ROOT=$ROCM_PATH/opencl
# Search order: /root/bin, /root/.local/bin, OpenCL, ROCm, HCC, HIP, then the
# inherited PATH. This is a separate ENV instruction because variables assigned
# in one ENV line are not visible to other assignments on that same line.
ENV PATH="/root/bin:/root/.local/bin:$OPENCL_ROOT/bin:$ROCM_PATH/bin:$HCC_HOME/bin:$HIP_PATH/bin:${PATH}"
# Python tooling: packaging helpers (build, wheel, auditwheel), pytest and its
# reporting/rerun plugins, plus libraries the JAX test suite imports.
# --break-system-packages is needed because Ubuntu 24.04's system Python is
# marked externally managed. pip's download cache lives on a BuildKit cache
# mount, so it speeds up rebuilds without being stored in the image.
# NOTE(review): numpy is capped below 2, presumably for compatibility with the
# jaxlib wheels installed later — confirm before lifting the cap.
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
    pip3 install --break-system-packages \
        "numpy<2" \
        build \
        wheel \
        six \
        auditwheel \
        scipy \
        pytest \
        pytest-html \
        pytest_html_merger \
        pytest-reportlog \
        pytest-rerunfailures \
        pytest-json-report \
        cloudpickle \
        portpicker \
        matplotlib \
        absl-py \
        flatbuffers \
        hypothesis
# Provenance labels. These ARGs are declared after the heavy install steps, so
# a new JAX/XLA revision invalidates only the layers from here on.
# python_version is hard-coded; keep it in sync with the base image's python3.
ARG JAX_VERSION
ARG JAX_COMMIT
ARG XLA_COMMIT
LABEL com.amdgpu.rocm_version="${ROCM_VERSION}" \
      com.amdgpu.python_version="3.12" \
      com.amdgpu.jax_version="${JAX_VERSION}" \
      com.amdgpu.jax_commit="${JAX_COMMIT}" \
      com.amdgpu.xla_commit="${XLA_COMMIT}"
# Install the locally built wheels from the build context's wheelhouse/ dir.
# It is bind-mounted at /wheelhouse for this step only, so the .whl files are
# not stored in the image. Absolute paths are used so the step does not depend
# on the current WORKDIR.
# First pass: wheels whose names contain "none" (e.g. the pure-Python jax
# wheel) plus jaxlib; second pass: the ROCm plugin wheels ("rocm60" in name).
# --break-system-packages is required here just as in the earlier pip install
# step: Ubuntu 24.04's system Python is externally managed (PEP 668), and pip
# refuses to install into it without the flag.
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
    --mount=type=bind,source=wheelhouse,target=/wheelhouse \
    ls -lah /wheelhouse \
    && pip3 install --break-system-packages /wheelhouse/*none*.whl /wheelhouse/*jaxlib*.whl \
    && pip3 install --break-system-packages /wheelhouse/*rocm60*.whl