mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 05:46:06 +00:00
add CPU-only build option
This commit is contained in:
parent
46c6a9170f
commit
7773167995
@ -1,6 +1,17 @@
|
||||
#!/bin/bash
|
||||
set -exv
|
||||
|
||||
# For a build with CUDA, from the repo root run:
|
||||
# bash build/build_jax.sh
|
||||
# For building without CUDA (CPU-only), instead run:
|
||||
# JAX_BUILD_WITH_CUDA=0 bash build/build_jax.sh
|
||||
# To clean intermediate results, run
|
||||
# rm -rf /tmp/jax-build/jax-bazel-output-user-root
|
||||
# To clean everything, run
|
||||
# rm -rf /tmp/jax-build
|
||||
|
||||
JAX_BUILD_WITH_CUDA=${JAX_BUILD_WITH_CUDA:-1}
|
||||
|
||||
init_commit=a30e858e59d7184b9e54dc3f3955238221d70439
|
||||
if [[ ! -d .git || $(git rev-list --parents HEAD | tail -1) != ${init_commit} ]]
|
||||
then
|
||||
@ -25,7 +36,7 @@ export PATH="${bazel_dir}/bin:$PATH"
|
||||
# BUG: https://github.com/bazelbuild/bazel/issues/6665
|
||||
handle_temporary_bazel_0_19_1_bug=1 # TODO(mattjj): remove with bazel 0.19.2
|
||||
|
||||
## get and configure tensorflow for building xla with gpu support
|
||||
## get and configure tensorflow for building xla
|
||||
if [[ ! -d tensorflow ]]
|
||||
then
|
||||
git clone https://github.com/tensorflow/tensorflow.git
|
||||
@ -34,12 +45,17 @@ pushd tensorflow
|
||||
export PYTHON_BIN_PATH=${PYTHON_BIN_PATH:-$(which python)}
|
||||
export PYTHON_LIB_PATH=${SP_DIR:-$(python -m site --user-site)}
|
||||
export USE_DEFAULT_PYTHON_LIB_PATH=1
|
||||
export CUDA_TOOLKIT_PATH=${CUDA_PATH:-/usr/local/cuda}
|
||||
export CUDNN_INSTALL_PATH=${CUDA_TOOLKIT_PATH}
|
||||
export TF_CUDA_VERSION=$(readlink -f ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so | cut -d '.' -f4-5)
|
||||
export TF_CUDNN_VERSION=$(readlink -f ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so | cut -d '.' -f4-5)
|
||||
export TF_CUDA_COMPUTE_CAPABILITIES="3.0,3.5,5.2,6.0,6.1,7.0"
|
||||
export TF_NEED_CUDA=1
|
||||
if [[ ${JAX_BUILD_WITH_CUDA} != 0 ]]
|
||||
then
|
||||
export CUDA_TOOLKIT_PATH=${CUDA_PATH:-/usr/local/cuda}
|
||||
export CUDNN_INSTALL_PATH=${CUDA_TOOLKIT_PATH}
|
||||
export TF_CUDA_VERSION=$(readlink -f ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so | cut -d '.' -f4-5)
|
||||
export TF_CUDNN_VERSION=$(readlink -f ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so | cut -d '.' -f4-5)
|
||||
export TF_CUDA_COMPUTE_CAPABILITIES="3.0,3.5,5.2,6.0,6.1,7.0"
|
||||
export TF_NEED_CUDA=1
|
||||
else
|
||||
export TF_NEED_CUDA=0
|
||||
fi
|
||||
export GCC_HOST_COMPILER_PATH="/usr/bin/gcc"
|
||||
export TF_ENABLE_XLA=1
|
||||
export TF_NEED_MKL=0
|
||||
@ -62,8 +78,14 @@ mkdir -p ${PYTHON_LIB_PATH}
|
||||
bazel_output_user_root=${tmp}/jax-bazel-output-user-root
|
||||
bazel_output_base=${bazel_output_user_root}/output-base
|
||||
bazel_opt="--output_user_root=${bazel_output_user_root} --output_base=${bazel_output_base} --bazelrc=tensorflow/tools/bazel.rc"
|
||||
bazel_build_opt="-c opt --config=cuda"
|
||||
if [ -n $handle_temporary_bazel_0_19_1_bug ]
|
||||
if [[ ${JAX_BUILD_WITH_CUDA} != 0 ]]
|
||||
then
|
||||
bazel_build_opt="-c opt --config=cuda"
|
||||
else
|
||||
bazel_build_opt="-c opt"
|
||||
fi
|
||||
# TODO(mattjj): remove this if/else clause with bazel 0.19.2 release
|
||||
if [[ -n $handle_temporary_bazel_0_19_1_bug && ${JAX_BUILD_WITH_CUDA} != 0 ]]
|
||||
then
|
||||
set +e
|
||||
bazel ${bazel_opt} build ${bazel_build_opt} jax:build_jax 2> /dev/null
|
||||
@ -74,9 +96,9 @@ bazel ${bazel_opt} build ${bazel_build_opt} jax:build_jax
|
||||
|
||||
## extract the pieces we need
|
||||
runfiles_prefix="execroot/__main__/bazel-out/k8-opt/bin/jax/build_jax.runfiles/org_tensorflow/tensorflow"
|
||||
cp ${bazel_output_base}/${runfiles_prefix}/libtensorflow_framework.so jax/lib/
|
||||
cp ${bazel_output_base}/${runfiles_prefix}/compiler/xla/xla_data_pb2.py jax/lib/
|
||||
cp ${bazel_output_base}/${runfiles_prefix}/compiler/xla/python/{xla_client.py,pywrap_xla.py,_pywrap_xla.so} jax/lib/
|
||||
cp -f ${bazel_output_base}/${runfiles_prefix}/libtensorflow_framework.so jax/lib/
|
||||
cp -f ${bazel_output_base}/${runfiles_prefix}/compiler/xla/xla_data_pb2.py jax/lib/
|
||||
cp -f ${bazel_output_base}/${runfiles_prefix}/compiler/xla/python/{xla_client.py,pywrap_xla.py,_pywrap_xla.so} jax/lib/
|
||||
|
||||
## rewrite some imports
|
||||
sed -i 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' jax/lib/xla_client.py
|
||||
@ -86,4 +108,5 @@ sed -i '/from tensorflow.compiler.xla.service import hlo_pb2/d' jax/lib/xla_clie
|
||||
## clean up
|
||||
rm -f bazel-* # symlinks
|
||||
rm -rf tensorflow
|
||||
# rm -rf ${tmp}
|
||||
rm -rf ${bazel_output_user_root} # clean build results
|
||||
# rm -rf ${tmp} # clean everything, including the bazel binary
|
||||
|
Loading…
x
Reference in New Issue
Block a user