mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Rewrite OSS build script.
Significant changes: * Mac OS X support. * build script is in Python, not shell. * build configuration is passed via flags, not environment variables. * build script configures TF itself, and does not require explicitly checking out the TF git repository and running its configure script. Changes the TF dependency in the Bazel workspace to be an http_archive(), rather than a local checkout of TF. * rather than trying to guess the path for Bazel-generated XLA artifacts, use a sh_binary() to perform installation of the built artifacts in to the JAX source tree. Bazel's runfiles mechanism is the supported route to find build artifacts. * downloads Bazel in Python and checks its SHA256 before running it, rather than running an untrusted binary from the internet. * intentionally does not delete the Bazel cache or Bazel after building. Example of new build interaction: Building without CUDA on Mac or Linux: $ cd jax $ python3 build.py (or python2 build.py if you want a Python 2 build) _ _ __ __ | | / \ \ \/ / _ | |/ _ \ \ / | |_| / ___ \ / \ \___/_/ \_\/_/\_\ Starting local Bazel server and connecting to it... Bazel binary path: /Users/xyz/bin/bazel Python binary path: /Library/Frameworks/Python.framework/Versions/3.7/bin/python3 CUDA enabled: no Building XLA and installing it in the JAX source tree... ... Example of building with CUDA enabled on Linux: $ python3 build.py --enable_cuda --cudnn_path=/usr/lib/x86_64-linux-gnu/ ... as before, except ... CUDA enabled: yes CUDA toolkit path: /usr/local/cuda CUDNN library path: /usr/lib/x86_64-linux-gnu/ ... PiperOrigin-RevId: 222868835
This commit is contained in:
parent
326773808b
commit
f3513a7bfb
27
WORKSPACE
27
WORKSPACE
@ -1,8 +1,3 @@
|
||||
local_repository(
|
||||
name = "org_tensorflow",
|
||||
path = "tensorflow",
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "io_bazel_rules_closure",
|
||||
sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
|
||||
@ -13,6 +8,28 @@ http_archive(
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# To update TensorFlow to a new revision,
|
||||
# a) update URL and strip_prefix to the new git commit hash
|
||||
# b) get the sha256 hash of the commit by running:
|
||||
# curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
|
||||
# and update the sha256 with the result.
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
sha256 = "599e9aad221a27882fce98ff472372030a2ebfe63cfc643d3470691f34bb68d6",
|
||||
strip_prefix="tensorflow-64e084b8cb27e8c53b15468c21f1b3471b4b9659",
|
||||
urls = [
|
||||
"https://github.com/tensorflow/tensorflow/archive/64e084b8cb27e8c53b15468c21f1b3471b4b9659.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
# For development, one can use a local TF repository instead.
|
||||
# local_repository(
|
||||
# name = "org_tensorflow",
|
||||
# path = "tensorflow",
|
||||
# )
|
||||
|
||||
|
||||
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
|
||||
|
||||
tf_workspace(
|
||||
|
279
build.py
Normal file
279
build.py
Normal file
@ -0,0 +1,279 @@
|
||||
#!/usr/bin/python
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Helper script for building JAX easily.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import hashlib
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
import urllib
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
if hasattr(urllib, "urlretrieve"):
|
||||
urlretrieve = urllib.urlretrieve
|
||||
else:
|
||||
import urllib.request
|
||||
urlretrieve = urllib.request.urlretrieve
|
||||
|
||||
if hasattr(shutil, "which"):
|
||||
which = shutil.which
|
||||
else:
|
||||
from distutils.spawn import find_executable as which
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
def shell(cmd):
|
||||
output = subprocess.check_output(cmd)
|
||||
return output.decode("UTF-8").strip()
|
||||
|
||||
|
||||
# Python
|
||||
|
||||
def get_python_bin_path(python_bin_path_flag):
|
||||
"""Returns the path to the Python interpreter to use."""
|
||||
return python_bin_path_flag or sys.executable
|
||||
|
||||
|
||||
# Bazel
|
||||
|
||||
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/0.19.2/"
|
||||
BazelPackage = collections.namedtuple("BazelPackage", ["file", "sha256"])
|
||||
bazel_packages = {
|
||||
"Linux":
|
||||
BazelPackage(
|
||||
file="bazel-0.19.2-linux-x86_64",
|
||||
sha256=
|
||||
"2ee9f23b49fb47725f725579c47f4f50272f4f9d23643e32add1fdef6aa0c5e0"),
|
||||
"Darwin":
|
||||
BazelPackage(
|
||||
file="bazel-0.19.2-darwin-x86_64",
|
||||
sha256=
|
||||
"74ae65127b46b59305fc5ea0c6baca355fce7e87c8624448e06f8cf2366b507e"),
|
||||
}
|
||||
|
||||
|
||||
def download_and_verify_bazel():
|
||||
"""Downloads a bazel binary from Github, verifying its SHA256 hash."""
|
||||
package = bazel_packages.get(platform.system())
|
||||
if package is None:
|
||||
return None
|
||||
|
||||
if not os.access(package.file, os.X_OK):
|
||||
uri = BAZEL_BASE_URI + package.file
|
||||
sys.stdout.write("Downloading bazel from: {}\n".format(uri))
|
||||
|
||||
def progress(block_count, block_size, total_size):
|
||||
if total_size <= 0:
|
||||
total_size = 170**6
|
||||
progress = (block_count * block_size) / total_size
|
||||
num_chars = 40
|
||||
progress_chars = int(num_chars * progress)
|
||||
sys.stdout.write("{} [{}{}] {}%\r".format(
|
||||
package.file, "#" * progress_chars,
|
||||
"." * (num_chars - progress_chars), int(progress * 100.0)))
|
||||
|
||||
tmp_path, _ = urlretrieve(uri, None, progress)
|
||||
sys.stdout.write("\n")
|
||||
|
||||
# Verify that the downloaded Bazel binary has the expected SHA256.
|
||||
downloaded_file = open(tmp_path, "rb")
|
||||
contents = downloaded_file.read()
|
||||
downloaded_file.close()
|
||||
digest = hashlib.sha256(contents).hexdigest()
|
||||
if digest != package.sha256:
|
||||
print(
|
||||
"Checksum mismatch for downloaded bazel binary (expected {}; got {})."
|
||||
.format(package.sha256, digest))
|
||||
sys.exit(-1)
|
||||
|
||||
# Write the file as the bazel file name.
|
||||
out_file = open(package.file, "wb")
|
||||
out_file.write(contents)
|
||||
out_file.close()
|
||||
|
||||
# Mark the file as executable.
|
||||
st = os.stat(package.file)
|
||||
os.chmod(package.file,
|
||||
st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
|
||||
|
||||
return "./" + package.file
|
||||
|
||||
|
||||
def get_bazel_path(bazel_path_flag):
|
||||
"""Returns the path to a Bazel binary, downloading Bazel if not found."""
|
||||
if bazel_path_flag:
|
||||
return bazel_path_flag
|
||||
|
||||
bazel = which("bazel")
|
||||
if bazel:
|
||||
return bazel
|
||||
|
||||
bazel = download_and_verify_bazel()
|
||||
if bazel:
|
||||
return bazel
|
||||
|
||||
print("Cannot find or download bazel. Please install bazel.")
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def check_bazel_version(bazel_path, min_version):
|
||||
"""Checks Bazel's version is at least `min_version`."""
|
||||
version_output = shell([bazel_path, "--bazelrc=/dev/null", "version"])
|
||||
match = re.search("Build label: *([0-9\\.]+)[^0-9\\.]", version_output)
|
||||
if match is None:
|
||||
print("Warning: bazel installation is not a release version. Make sure "
|
||||
"bazel is at least 0.19.2")
|
||||
return
|
||||
version = match.group(1)
|
||||
min_ints = [int(x) for x in min_version.split(".")]
|
||||
actual_ints = [int(x) for x in match.group(1).split(".")]
|
||||
if min_ints > actual_ints:
|
||||
print("Outdated bazel revision (>= {} required, found {})".format(
|
||||
min_version, version))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
BAZELRC_TEMPLATE = """
|
||||
build --action_env PYTHON_BIN_PATH="{python_bin_path}"
|
||||
build --python_path="{python_bin_path}"
|
||||
build --action_env TF_NEED_CUDA="{tf_need_cuda}"
|
||||
build --action_env CUDA_TOOLKIT_PATH="{cuda_toolkit_path}"
|
||||
build --action_env CUDNN_INSTALL_PATH="{cudnn_install_path}"
|
||||
build:opt --copt=-march=native
|
||||
build:opt --copt=-Wno-sign-compare
|
||||
build:opt --host_copt=-march=native
|
||||
"""
|
||||
|
||||
|
||||
def write_bazelrc(**kwargs):
|
||||
f = open(".bazelrc", "w")
|
||||
f.write(BAZELRC_TEMPLATE.format(**kwargs))
|
||||
f.close()
|
||||
|
||||
|
||||
BANNER = r"""
|
||||
_ _ __ __
|
||||
| | / \ \ \/ /
|
||||
_ | |/ _ \ \ /
|
||||
| |_| / ___ \ / \
|
||||
\___/_/ \_\/_/\_\
|
||||
|
||||
"""
|
||||
|
||||
EPILOG = """
|
||||
|
||||
From the JAX repository root, run
|
||||
python build.py
|
||||
or
|
||||
python3 build.py
|
||||
|
||||
Downloads and builds JAX's XLA dependency, installing XLA in the JAX source
|
||||
tree.
|
||||
"""
|
||||
|
||||
|
||||
def _parse_string_as_bool(s):
|
||||
"""Parses a string as a boolean argument."""
|
||||
lower = s.lower()
|
||||
if lower == "true":
|
||||
return True
|
||||
elif lower == "false":
|
||||
return False
|
||||
else:
|
||||
raise ValueError("Expected either 'true' or 'false'; got {}".format(s))
|
||||
|
||||
|
||||
def add_boolean_argument(parser, name, default=False, help_str=None):
|
||||
"""Creates a boolean flag."""
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
"--" + name,
|
||||
nargs="?",
|
||||
default=default,
|
||||
const=True,
|
||||
type=_parse_string_as_bool,
|
||||
help=help_str)
|
||||
group.add_argument("--no" + name, dest=name, action="store_false")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Builds JAX from source.", epilog=EPILOG)
|
||||
parser.add_argument(
|
||||
"--bazel_path",
|
||||
help="Path to the Bazel binary to use. The default is to find bazel via "
|
||||
"the PATH; if none is found, downloads a fresh copy of bazel from "
|
||||
"GitHub.")
|
||||
parser.add_argument(
|
||||
"--python_bin_path",
|
||||
help="Path to Python binary to use. The default is the Python "
|
||||
"interpreter used to run the build script.")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"enable_cuda",
|
||||
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
|
||||
parser.add_argument(
|
||||
"--cuda_path",
|
||||
default="/usr/local/cuda",
|
||||
help="Path to the CUDA toolkit.")
|
||||
parser.add_argument(
|
||||
"--cudnn_path",
|
||||
default="/usr/local/cuda",
|
||||
help="Path to CUDNN libraries.")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(BANNER)
|
||||
|
||||
# Find a working Bazel.
|
||||
bazel_path = get_bazel_path(args.bazel_path)
|
||||
check_bazel_version(bazel_path, "0.19.2")
|
||||
print("Bazel binary path: {}".format(bazel_path))
|
||||
|
||||
python_bin_path = get_python_bin_path(args.python_bin_path)
|
||||
print("Python binary path: {}".format(python_bin_path))
|
||||
|
||||
cuda_toolkit_path = args.cuda_path
|
||||
cudnn_install_path = args.cudnn_path
|
||||
print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no"))
|
||||
if args.enable_cuda:
|
||||
print("CUDA toolkit path: {}".format(cuda_toolkit_path))
|
||||
print("CUDNN library path: {}".format(cudnn_install_path))
|
||||
write_bazelrc(
|
||||
python_bin_path=python_bin_path,
|
||||
tf_need_cuda=1 if args.enable_cuda else 0,
|
||||
cuda_toolkit_path=cuda_toolkit_path,
|
||||
cudnn_install_path=cudnn_install_path)
|
||||
|
||||
print("\nBuilding XLA and installing it in the JAX source tree...")
|
||||
shell([
|
||||
bazel_path, "run", "-c", "opt", "//build:install_xla_in_source_tree",
|
||||
os.getcwd()
|
||||
])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
28
build/BUILD
Normal file
28
build/BUILD
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# JAX is Autograd and XLA
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
sh_binary(
|
||||
name = "install_xla_in_source_tree",
|
||||
srcs = ["install_xla_in_source_tree.sh"],
|
||||
data = [
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
|
||||
],
|
||||
deps = ["@bazel_tools//tools/bash/runfiles"],
|
||||
)
|
67
build/install_xla_in_source_tree.sh
Executable file
67
build/install_xla_in_source_tree.sh
Executable file
@ -0,0 +1,67 @@
|
||||
#!/bin/sh
|
||||
#
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Script that installs JAX's XLA dependencies inside the JAX source tree.
|
||||
|
||||
# --- begin runfiles.bash initialization ---
|
||||
# Copy-pasted from Bazel's Bash runfiles library (tools/bash/runfiles/runfiles.bash).
|
||||
set -euo pipefail
|
||||
if [[ ! -d "${RUNFILES_DIR:-/dev/null}" && ! -f "${RUNFILES_MANIFEST_FILE:-/dev/null}" ]]; then
|
||||
if [[ -f "$0.runfiles_manifest" ]]; then
|
||||
export RUNFILES_MANIFEST_FILE="$0.runfiles_manifest"
|
||||
elif [[ -f "$0.runfiles/MANIFEST" ]]; then
|
||||
export RUNFILES_MANIFEST_FILE="$0.runfiles/MANIFEST"
|
||||
elif [[ -f "$0.runfiles/bazel_tools/tools/bash/runfiles/runfiles.bash" ]]; then
|
||||
export RUNFILES_DIR="$0.runfiles"
|
||||
fi
|
||||
fi
|
||||
if [[ -f "${RUNFILES_DIR:-/dev/null}/bazel_tools/tools/bash/runfiles/runfiles.bash" ]]; then
|
||||
source "${RUNFILES_DIR}/bazel_tools/tools/bash/runfiles/runfiles.bash"
|
||||
elif [[ -f "${RUNFILES_MANIFEST_FILE:-/dev/null}" ]]; then
|
||||
source "$(grep -m1 "^bazel_tools/tools/bash/runfiles/runfiles.bash " \
|
||||
"$RUNFILES_MANIFEST_FILE" | cut -d ' ' -f 2-)"
|
||||
else
|
||||
echo >&2 "ERROR: cannot find @bazel_tools//tools/bash/runfiles:runfiles.bash"
|
||||
exit 1
|
||||
fi
|
||||
# --- end runfiles.bash initialization ---
|
||||
|
||||
if [[ $# -ne 1 ]]; then
|
||||
echo "Usage: $0 <target directory>"
|
||||
exit 1
|
||||
fi
|
||||
TARGET="$1"
|
||||
|
||||
if [[ ! -r "${TARGET}/jax/lax.py" ]]; then
|
||||
echo "Target directory ${TARGET} does not seem to be a JAX source tree" \
|
||||
"(missing jax/lax.py)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Copy the XLA dependencies into jax/lib, fixing up some imports to point to the
|
||||
# new location.
|
||||
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/xla_data_pb2.py)" \
|
||||
"${TARGET}/jax/lib"
|
||||
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/pywrap_xla.py)" \
|
||||
"${TARGET}/jax/lib"
|
||||
cp -f "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/_pywrap_xla.so)" \
|
||||
"${TARGET}/jax/lib"
|
||||
sed \
|
||||
-e 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' \
|
||||
-e 's/from tensorflow.compiler.xla import xla_data_pb2/from . import xla_data_pb2/' \
|
||||
-e '/from tensorflow.compiler.xla.service import hlo_pb2/d' \
|
||||
< "$(rlocation org_tensorflow/tensorflow/compiler/xla/python/xla_client.py)" \
|
||||
> "${TARGET}/jax/lib/xla_client.py"
|
@ -1,107 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -exv
|
||||
|
||||
# For a build with CUDA, from the repo root run:
|
||||
# bash build/build_jax.sh
|
||||
# For building without CUDA (CPU-only), instead run:
|
||||
# JAX_BUILD_WITH_CUDA=0 bash build/build_jax.sh
|
||||
# To clean intermediate results, run
|
||||
# rm -rf /tmp/jax-build/jax-bazel-output-user-root
|
||||
# To clean everything, run
|
||||
# rm -rf /tmp/jax-build
|
||||
|
||||
JAX_BUILD_WITH_CUDA=${JAX_BUILD_WITH_CUDA:-1}
|
||||
|
||||
init_commit=a30e858e59d7184b9e54dc3f3955238221d70439
|
||||
if [[ ! -d .git || $(git rev-list --parents HEAD | tail -1) != ${init_commit} ]]
|
||||
then
|
||||
(>&2 echo "must be executed from jax repo root")
|
||||
exit 1
|
||||
fi
|
||||
|
||||
tmp=/tmp/jax-build # could mktemp -d but this way we can cache results
|
||||
mkdir -p ${tmp}
|
||||
|
||||
## get bazel
|
||||
bazel_dir=${tmp}/jax-bazel
|
||||
if [ ! -d ${bazel_dir}/bin ]
|
||||
then
|
||||
mkdir -p ${bazel_dir}
|
||||
case "$(uname -s)" in
|
||||
Linux*) installer=bazel-0.19.2-installer-linux-x86_64.sh;;
|
||||
Darwin*) installer=bazel-0.19.2-installer-darwin-x86_64.sh;;
|
||||
*) exit 1;;
|
||||
esac
|
||||
curl -OL https://github.com/bazelbuild/bazel/releases/download/0.19.2/${installer}
|
||||
chmod +x ${installer}
|
||||
bash ${installer} --prefix=${bazel_dir}
|
||||
rm ${installer}
|
||||
fi
|
||||
export PATH="${bazel_dir}/bin:$PATH"
|
||||
|
||||
## get and configure tensorflow for building xla
|
||||
if [[ ! -d tensorflow ]]
|
||||
then
|
||||
git clone https://github.com/tensorflow/tensorflow.git
|
||||
fi
|
||||
pushd tensorflow
|
||||
export PYTHON_BIN_PATH=${PYTHON_BIN_PATH:-$(which python)}
|
||||
export PYTHON_LIB_PATH=${SP_DIR:-$(python -m site --user-site)}
|
||||
export USE_DEFAULT_PYTHON_LIB_PATH=1
|
||||
if [[ ${JAX_BUILD_WITH_CUDA} != 0 ]]
|
||||
then
|
||||
export CUDA_TOOLKIT_PATH=${CUDA_PATH:-/usr/local/cuda}
|
||||
export CUDNN_INSTALL_PATH=${CUDA_TOOLKIT_PATH}
|
||||
export TF_CUDA_VERSION=$(readlink -f ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so | cut -d '.' -f4-5)
|
||||
export TF_CUDNN_VERSION=$(readlink -f ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so | cut -d '.' -f4-5)
|
||||
export TF_CUDA_COMPUTE_CAPABILITIES="3.0,3.5,5.2,6.0,6.1,7.0"
|
||||
export TF_NCCL_VERSION=2
|
||||
export TF_NEED_CUDA=1
|
||||
else
|
||||
export TF_NEED_CUDA=0
|
||||
fi
|
||||
export GCC_HOST_COMPILER_PATH="/usr/bin/gcc"
|
||||
export TF_ENABLE_XLA=1
|
||||
export TF_NEED_MKL=0
|
||||
export CC_OPT_FLAGS="-march=native -Wno-sign-compare"
|
||||
export TF_NEED_IGNITE=1
|
||||
export TF_NEED_OPENCL=0
|
||||
export TF_NEED_OPENCL_SYCL=0
|
||||
export TF_NEED_ROCM=0
|
||||
export TF_NEED_MPI=0
|
||||
export TF_DOWNLOAD_CLANG=0
|
||||
export TF_SET_ANDROID_WORKSPACE=0
|
||||
export TF_CUDA_CLANG=0
|
||||
export TF_NEED_TENSORRT=0
|
||||
./configure
|
||||
popd
|
||||
|
||||
## build xla inside tensorflow
|
||||
mkdir -p ${PYTHON_LIB_PATH}
|
||||
bazel_output_user_root=${tmp}/jax-bazel-output-user-root
|
||||
bazel_output_base=${bazel_output_user_root}/output-base
|
||||
bazel_opt="--output_user_root=${bazel_output_user_root} --output_base=${bazel_output_base} --bazelrc=tensorflow/tools/bazel.rc"
|
||||
if [[ ${JAX_BUILD_WITH_CUDA} != 0 ]]
|
||||
then
|
||||
bazel_build_opt="-c opt --config=cuda"
|
||||
else
|
||||
bazel_build_opt="-c opt"
|
||||
fi
|
||||
bazel ${bazel_opt} build ${bazel_build_opt} jax:build_jax
|
||||
|
||||
## extract the pieces we need
|
||||
runfiles_prefix="execroot/__main__/bazel-out/k8-opt/bin/jax/build_jax.runfiles/org_tensorflow/tensorflow"
|
||||
cp -f ${bazel_output_base}/${runfiles_prefix}/libtensorflow_framework.so jax/lib/
|
||||
cp -f ${bazel_output_base}/${runfiles_prefix}/compiler/xla/xla_data_pb2.py jax/lib/
|
||||
cp -f ${bazel_output_base}/${runfiles_prefix}/compiler/xla/python/{xla_client.py,pywrap_xla.py,_pywrap_xla.so} jax/lib/
|
||||
|
||||
## rewrite some imports
|
||||
sed -i 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' jax/lib/xla_client.py
|
||||
sed -i 's/from tensorflow.compiler.xla import xla_data_pb2/from . import xla_data_pb2/' jax/lib/xla_client.py
|
||||
sed -i '/from tensorflow.compiler.xla.service import hlo_pb2/d' jax/lib/xla_client.py
|
||||
|
||||
## clean up
|
||||
rm -f bazel-* # symlinks
|
||||
rm -rf tensorflow
|
||||
rm -rf ${bazel_output_user_root} # clean build results
|
||||
# rm -rf ${tmp} # clean everything, including the bazel binary
|
Loading…
x
Reference in New Issue
Block a user