1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

add CPU-only build option

This commit is contained in:
Matthew Johnson 2018-11-19 08:30:37 -08:00
parent 46c6a9170f
commit 7773167995

@ -1,6 +1,17 @@
#!/bin/bash
set -exv
# For a build with CUDA, from the repo root run:
# bash build/build_jax.sh
# For building without CUDA (CPU-only), instead run:
# JAX_BUILD_WITH_CUDA=0 bash build/build_jax.sh
# To clean intermediate results, run
# rm -rf /tmp/jax-build/jax-bazel-output-user-root
# To clean everything, run
# rm -rf /tmp/jax-build
JAX_BUILD_WITH_CUDA=${JAX_BUILD_WITH_CUDA:-1}
init_commit=a30e858e59d7184b9e54dc3f3955238221d70439
if [[ ! -d .git || $(git rev-list --parents HEAD | tail -1) != ${init_commit} ]]
then
@ -25,7 +36,7 @@ export PATH="${bazel_dir}/bin:$PATH"
# BUG: https://github.com/bazelbuild/bazel/issues/6665
handle_temporary_bazel_0_19_1_bug=1 # TODO(mattjj): remove with bazel 0.19.2
## get and configure tensorflow for building xla with gpu support
## get and configure tensorflow for building xla
if [[ ! -d tensorflow ]]
then
git clone https://github.com/tensorflow/tensorflow.git
@ -34,12 +45,17 @@ pushd tensorflow
export PYTHON_BIN_PATH=${PYTHON_BIN_PATH:-$(which python)}
export PYTHON_LIB_PATH=${SP_DIR:-$(python -m site --user-site)}
export USE_DEFAULT_PYTHON_LIB_PATH=1
export CUDA_TOOLKIT_PATH=${CUDA_PATH:-/usr/local/cuda}
export CUDNN_INSTALL_PATH=${CUDA_TOOLKIT_PATH}
export TF_CUDA_VERSION=$(readlink -f ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so | cut -d '.' -f4-5)
export TF_CUDNN_VERSION=$(readlink -f ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so | cut -d '.' -f4-5)
export TF_CUDA_COMPUTE_CAPABILITIES="3.0,3.5,5.2,6.0,6.1,7.0"
export TF_NEED_CUDA=1
if [[ ${JAX_BUILD_WITH_CUDA} != 0 ]]
then
export CUDA_TOOLKIT_PATH=${CUDA_PATH:-/usr/local/cuda}
export CUDNN_INSTALL_PATH=${CUDA_TOOLKIT_PATH}
export TF_CUDA_VERSION=$(readlink -f ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so | cut -d '.' -f4-5)
export TF_CUDNN_VERSION=$(readlink -f ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so | cut -d '.' -f4-5)
export TF_CUDA_COMPUTE_CAPABILITIES="3.0,3.5,5.2,6.0,6.1,7.0"
export TF_NEED_CUDA=1
else
export TF_NEED_CUDA=0
fi
export GCC_HOST_COMPILER_PATH="/usr/bin/gcc"
export TF_ENABLE_XLA=1
export TF_NEED_MKL=0
@ -62,8 +78,14 @@ mkdir -p ${PYTHON_LIB_PATH}
bazel_output_user_root=${tmp}/jax-bazel-output-user-root
bazel_output_base=${bazel_output_user_root}/output-base
bazel_opt="--output_user_root=${bazel_output_user_root} --output_base=${bazel_output_base} --bazelrc=tensorflow/tools/bazel.rc"
bazel_build_opt="-c opt --config=cuda"
if [ -n $handle_temporary_bazel_0_19_1_bug ]
if [[ ${JAX_BUILD_WITH_CUDA} != 0 ]]
then
bazel_build_opt="-c opt --config=cuda"
else
bazel_build_opt="-c opt"
fi
# TODO(mattjj): remove this if/else clause with bazel 0.19.2 release
if [[ -n $handle_temporary_bazel_0_19_1_bug && ${JAX_BUILD_WITH_CUDA} != 0 ]]
then
set +e
bazel ${bazel_opt} build ${bazel_build_opt} jax:build_jax 2> /dev/null
@ -74,9 +96,9 @@ bazel ${bazel_opt} build ${bazel_build_opt} jax:build_jax
## extract the pieces we need
runfiles_prefix="execroot/__main__/bazel-out/k8-opt/bin/jax/build_jax.runfiles/org_tensorflow/tensorflow"
cp ${bazel_output_base}/${runfiles_prefix}/libtensorflow_framework.so jax/lib/
cp ${bazel_output_base}/${runfiles_prefix}/compiler/xla/xla_data_pb2.py jax/lib/
cp ${bazel_output_base}/${runfiles_prefix}/compiler/xla/python/{xla_client.py,pywrap_xla.py,_pywrap_xla.so} jax/lib/
cp -f ${bazel_output_base}/${runfiles_prefix}/libtensorflow_framework.so jax/lib/
cp -f ${bazel_output_base}/${runfiles_prefix}/compiler/xla/xla_data_pb2.py jax/lib/
cp -f ${bazel_output_base}/${runfiles_prefix}/compiler/xla/python/{xla_client.py,pywrap_xla.py,_pywrap_xla.so} jax/lib/
## rewrite some imports
sed -i 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' jax/lib/xla_client.py
@ -86,4 +108,5 @@ sed -i '/from tensorflow.compiler.xla.service import hlo_pb2/d' jax/lib/xla_clie
## clean up
rm -f bazel-* # symlinks
rm -rf tensorflow
# rm -rf ${tmp}
rm -rf ${bazel_output_user_root} # clean build results
# rm -rf ${tmp} # clean everything, including the bazel binary