# Image for running JAX on AMD GPUs: Ubuntu 22.04 plus a ROCm install, a Python
# test/packaging toolchain, and JAX wheels installed from a local "wheelhouse"
# directory in the build context.
#
# Requires BuildKit: several steps use `RUN --mount`.
FROM ubuntu:22.04

# Base OS packages: the Python interpreter (with a `python` alias) and the
# bzip2 / sqlite3 development libraries.
# - update, install and list cleanup happen in ONE layer, so the apt package
#   indexes are never persisted in the image.
# - sharing=locked serializes access to the apt cache mount so parallel builds
#   do not contend on apt's lock files.
# - DEBIAN_FRONTEND is set inline only, so it does not leak into the runtime env.
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
    apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y \
        libbz2-dev \
        libsqlite3-dev \
        python-is-python3 \
        python3 \
        sqlite3 && \
    rm -rf /var/lib/apt/lists/*

# GPU architectures to build for; exported so it is also visible at runtime.
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}

# Install ROCm.
# ROCM_BUILD_JOB / ROCM_BUILD_NUM have no defaults; when not supplied they are
# passed to get_rocm.py as empty strings.
ARG ROCM_VERSION=6.0.0
ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION}
ENV ROCM_PATH=${ROCM_PATH}
ARG ROCM_BUILD_JOB
ARG ROCM_BUILD_NUM

# get_rocm.py is bind-mounted from the build context for this step only, so it
# is not copied into the image. Arguments are quoted so a value containing
# whitespace cannot be split into several arguments.
RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
    --mount=type=cache,target=/var/cache/apt,sharing=locked \
    python3 get_rocm.py \
        --rocm-version="$ROCM_VERSION" \
        --job-name="$ROCM_BUILD_JOB" \
        --build-num="$ROCM_BUILD_NUM"

# Set up paths.
# The *_HOME / *_PATH / *_ROOT variables are declared in their own instruction
# because an ENV instruction cannot reference values it sets itself.
# HIP_PATH intentionally keeps its trailing slash (unchanged value).
ENV HCC_HOME=$ROCM_PATH/hcc \
    HIP_PATH=$ROCM_PATH/ \
    OPENCL_ROOT=$ROCM_PATH/opencl

# Search order, highest priority first: user-local bins, OpenCL, ROCm, HCC,
# HIP, then whatever PATH the base image provided.
ENV PATH="/root/bin:/root/.local/bin:${OPENCL_ROOT}/bin:${ROCM_PATH}/bin:${HCC_HOME}/bin:${HIP_PATH}/bin:${PATH}"

# Python packaging, numeric and test tooling.
# NOTE(review): pip3 is not installed by the apt step above; presumably it is
# provided by the ROCm installation step -- confirm.
# numpy is held below 2.x.
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
    pip3 install --upgrade --force-reinstall setuptools pip && \
    pip3 install \
        absl-py \
        auditwheel \
        build \
        cloudpickle \
        flatbuffers \
        hypothesis \
        matplotlib \
        "numpy<2" \
        portpicker \
        pytest \
        pytest-html \
        pytest-json-report \
        pytest-reportlog \
        pytest-rerunfailures \
        pytest_html_merger \
        scipy \
        six \
        wheel

# Provenance metadata, supplied by the caller via --build-arg.
ARG JAX_VERSION
ARG JAX_COMMIT
ARG XLA_COMMIT

# NOTE(review): python_version is hard-coded; it presumably matches the python3
# package shipped by the ubuntu:22.04 base -- confirm if the base image changes.
LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
      com.amdgpu.python_version="3.10" \
      com.amdgpu.jax_version="$JAX_VERSION" \
      com.amdgpu.jax_commit="$JAX_COMMIT" \
      com.amdgpu.xla_commit="$XLA_COMMIT"

# Install the wheels from the build context's wheelhouse/ directory, which is
# bind-mounted at /wheelhouse for this step only. Absolute paths are used so
# the step does not depend on the current working directory.
# The "*none*" and "*jaxlib*" wheels are installed first, then the ROCm wheels.
# NOTE(review): the "*rocm60*" glob is hard-coded and does not follow
# ROCM_VERSION -- confirm the wheel naming before building for another release.
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
    --mount=type=bind,source=wheelhouse,target=/wheelhouse \
    ls -lah /wheelhouse && \
    pip3 install /wheelhouse/*none*.whl /wheelhouse/*jaxlib*.whl && \
    pip3 install /wheelhouse/*rocm60*.whl