From a9e54b3e0a0e01300b983679c1609df14f20aa40 Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Tue, 20 Aug 2024 15:29:43 -0500 Subject: [PATCH] Add docker builds for ubu22 and 24 --- build/rocm/Dockerfile.ms | 36 +++++++++++---- build/rocm/docker/Dockerfile.jax-ubu22 | 64 ++++++++++++++++++++++++++ build/rocm/docker/Dockerfile.jax-ubu24 | 63 +++++++++++++++++++++++++ build/rocm/docker/Makefile | 20 ++++++++ 4 files changed, 174 insertions(+), 9 deletions(-) create mode 100644 build/rocm/docker/Dockerfile.jax-ubu22 create mode 100644 build/rocm/docker/Dockerfile.jax-ubu24 create mode 100644 build/rocm/docker/Makefile diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index dffe42de7..0bcc89f49 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -3,10 +3,10 @@ FROM ubuntu:20.04 AS rocm_base ################################################################################ RUN --mount=type=cache,target=/var/cache/apt \ - apt-get update && apt-get install -y python3 + apt-get update && apt-get install -y python3 python-is-python3 # Add target file to help determine which device(s) to build for -ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} # Install ROCM @@ -16,6 +16,7 @@ ENV ROCM_PATH=${ROCM_PATH} ARG ROCM_BUILD_JOB ARG ROCM_BUILD_NUM RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM # Set up paths @@ -42,14 +43,30 @@ RUN git clone https://github.com/pyenv/pyenv.git /pyenv ENV PYENV_ROOT /pyenv ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH RUN pyenv install $PYTHON_VERSION -RUN eval "$(pyenv init -)" && \ +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + eval "$(pyenv init -)" && \ pyenv local ${PYTHON_VERSION} && \ pip3 install --upgrade --force-reinstall setuptools pip && \ - pip install \ - numpy setuptools build wheel six auditwheel scipy \ - pytest pytest-html pytest_html_merger pytest-reportlog \ - pytest-rerunfailures cloudpickle portpicker matplotlib absl-py \ - flatbuffers hypothesis pytest-json-report pytest-csv + pip3 install \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis ################################################################################ FROM rocm_base AS rt_build @@ -65,6 +82,7 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ com.amdgpu.jax_commit="$JAX_COMMIT" \ com.amdgpu.xla_commit="$XLA_COMMIT" -RUN --mount=type=bind,source=wheelhouse,target=/wheelhouse \ +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + --mount=type=bind,source=wheelhouse,target=/wheelhouse \ pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22 new file mode 100644 index 000000000..ba64efbbc --- /dev/null +++ b/build/rocm/docker/Dockerfile.jax-ubu22 @@ -0,0 +1,64 @@ +FROM ubuntu:22.04 + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y python3 python-is-python3 + +# Add target file to help determine which device(s) to build for +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} + +# Install ROCM +ARG ROCM_VERSION=6.0.0 +ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} +ENV ROCM_PATH=${ROCM_PATH} +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM +RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM + +# Set up paths +ENV HCC_HOME=$ROCM_PATH/hcc +ENV HIP_PATH=$ROCM_PATH/ +ENV OPENCL_ROOT=$ROCM_PATH/opencl +ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" +ENV PATH="$ROCM_PATH/bin:${PATH}" +ENV PATH="$OPENCL_ROOT/bin:${PATH}" +ENV PATH="/root/bin:/root/.local/bin:$PATH" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + pip3 install --upgrade --force-reinstall setuptools pip && \ + pip3 install \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT + +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="3.10" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + --mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip3 install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/docker/Dockerfile.jax-ubu24 b/build/rocm/docker/Dockerfile.jax-ubu24 new file mode 100644 index 000000000..44c59b1b7 --- /dev/null +++ b/build/rocm/docker/Dockerfile.jax-ubu24 @@ -0,0 +1,63 @@ +FROM ubuntu:24.04 + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y python3 python-is-python3 python3-pip + +# Add target file to help determine which device(s) to build for +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} + +# Install ROCM +ARG ROCM_VERSION=6.2.0 +ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} +ENV ROCM_PATH=${ROCM_PATH} +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM +RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM + +# Set up paths +ENV HCC_HOME=$ROCM_PATH/hcc +ENV HIP_PATH=$ROCM_PATH/ +ENV OPENCL_ROOT=$ROCM_PATH/opencl +ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" +ENV PATH="$ROCM_PATH/bin:${PATH}" +ENV PATH="$OPENCL_ROOT/bin:${PATH}" +ENV PATH="/root/bin:/root/.local/bin:$PATH" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + pip3 install --break-system-packages \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT + +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="3.12" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + --mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip3 install --break-system-packages --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/docker/Makefile b/build/rocm/docker/Makefile new file mode 100644 index 000000000..7fb38a936 --- /dev/null +++ b/build/rocm/docker/Makefile @@ -0,0 +1,20 @@ +.PHONY: all clean + +all: .docker-jax-ubu22 .docker-jax-ubu24 + +clean: clean-jax-ubu22 clean-jax-ubu24 + +ROCM_VERSION = 6.2.0 + +.docker-% : build/rocm/docker/Dockerfile.% + docker build -f $< --tag $(*F) --progress plain \ + --build-arg=ROCM_VERSION=${ROCM_VERSION} \ + --build-arg=JAX_VERSION=$(shell python setup.py -V) \ + --build-arg=JAX_COMMIT=$(shell git rev-parse HEAD) \ + . + @touch $@ + + +clean-%: + -docker rmi $(*F) + @rm -f .docker-$(*F)