[JAX] Rewrite OSS build script.

Significant changes:
* Mac OS X support.
* build script is in Python, not shell.
* build configuration is passed via flags, not environment variables.
* build script configures TF itself, and does not require explicitly checking out the TF git repository and running its configure script. Changes the TF dependency in the Bazel workspace to be an http_archive(), rather than a local checkout of TF.
* rather than trying to guess the path for Bazel-generated XLA artifacts, uses an sh_binary() to install the built artifacts into the JAX source tree. Bazel's runfiles mechanism is the supported route to find build artifacts.
* downloads Bazel in Python and checks its SHA256 before running it, rather than running an untrusted binary from the internet.
* intentionally does not delete the Bazel cache or Bazel after building.

Example of new build interaction:

Building without CUDA on Mac or Linux:
$ cd jax
$ python3 build.py   (or python2 build.py if you want a Python 2 build)

     _   _    __  __
    | | / \   \ \/ /
 _  | |/ _ \   \  /
| |_| / ___ \  /  \
 \___/_/   \_\/_/\_\

Starting local Bazel server and connecting to it...
Bazel binary path: /Users/xyz/bin/bazel
Python binary path: /Library/Frameworks/Python.framework/Versions/3.7/bin/python3
CUDA enabled: no

Building XLA and installing it in the JAX source tree...
...

Example of building with CUDA enabled on Linux:
$ python3 build.py --enable_cuda --cudnn_path=/usr/lib/x86_64-linux-gnu/
... as before, except ...
CUDA enabled: yes
CUDA toolkit path: /usr/local/cuda
CUDNN library path: /usr/lib/x86_64-linux-gnu/
...

PiperOrigin-RevId: 222868835
This commit is contained in:
Peter Hawkins 2018-11-26 12:37:24 -08:00 committed by Roy Frostig
parent 326773808b
commit f3513a7bfb
5 changed files with 396 additions and 112 deletions

View File

@ -1,8 +1,3 @@
local_repository(
name = "org_tensorflow",
path = "tensorflow",
)
http_archive(
name = "io_bazel_rules_closure",
sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
@ -13,6 +8,28 @@ http_archive(
],
)
# TensorFlow is fetched as a source archive pinned to a single git commit and
# verified against the sha256 below.
# To update TensorFlow to a new revision,
# a) update URL and strip_prefix to the new git commit hash
# b) get the sha256 hash of the commit by running:
# curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "599e9aad221a27882fce98ff472372030a2ebfe63cfc643d3470691f34bb68d6",
strip_prefix="tensorflow-64e084b8cb27e8c53b15468c21f1b3471b4b9659",
urls = [
"https://github.com/tensorflow/tensorflow/archive/64e084b8cb27e8c53b15468c21f1b3471b4b9659.tar.gz",
],
)
# For development, one can use a local TF repository instead.
# local_repository(
# name = "org_tensorflow",
# path = "tensorflow",
# )
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace(

279
build.py Normal file
View File

@ -0,0 +1,279 @@
#!/usr/bin/python
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Helper script for building JAX easily.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import hashlib
import os
import platform
import re
import shutil
import stat
import subprocess
import sys
import urllib
# pylint: disable=g-import-not-at-top
if hasattr(urllib, "urlretrieve"):
urlretrieve = urllib.urlretrieve
else:
import urllib.request
urlretrieve = urllib.request.urlretrieve
if hasattr(shutil, "which"):
which = shutil.which
else:
from distutils.spawn import find_executable as which
# pylint: enable=g-import-not-at-top
def shell(cmd):
  """Runs `cmd` and returns its standard output as stripped text.

  Args:
    cmd: the command to run, handed unchanged to `subprocess.check_output`.

  Returns:
    The command's stdout decoded as UTF-8, with leading and trailing
    whitespace removed.

  Raises:
    subprocess.CalledProcessError: if the command exits with a nonzero status.
  """
  raw_bytes = subprocess.check_output(cmd)
  text = raw_bytes.decode("UTF-8")
  return text.strip()
# Python
def get_python_bin_path(python_bin_path_flag):
  """Returns the path to the Python interpreter to use.

  Args:
    python_bin_path_flag: value of the --python_bin_path flag; may be None or
      empty when the flag was not given.

  Returns:
    The flag value when it is set, otherwise the interpreter running this
    script (`sys.executable`).
  """
  if python_bin_path_flag:
    return python_bin_path_flag
  return sys.executable
# Bazel
# Directory of the GitHub release from which Bazel binaries are downloaded.
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/0.19.2/"

# A downloadable Bazel binary: the release file name and the SHA256 hex digest
# the downloaded bytes must have.
BazelPackage = collections.namedtuple("BazelPackage", ["file", "sha256"])

# Known Bazel binaries, keyed by the value of platform.system().
bazel_packages = {
    "Linux": BazelPackage(
        file="bazel-0.19.2-linux-x86_64",
        sha256=(
            "2ee9f23b49fb47725f725579c47f4f50272f4f9d23643e32add1fdef6aa0c5e0"
        ),
    ),
    "Darwin": BazelPackage(
        file="bazel-0.19.2-darwin-x86_64",
        sha256=(
            "74ae65127b46b59305fc5ea0c6baca355fce7e87c8624448e06f8cf2366b507e"
        ),
    ),
}
def download_and_verify_bazel():
  """Downloads a bazel binary from Github, verifying its SHA256 hash.

  The binary is stored in the current working directory under its release
  file name. If an executable file with that name already exists it is reused
  as-is: nothing is downloaded and its hash is not re-checked.

  Returns:
    A relative path ("./<file name>") to the Bazel binary, or None when no
    binary is listed in `bazel_packages` for the current platform.

  Exits the process with status -1 if the downloaded bytes do not have the
  expected SHA256 digest.
  """
  package = bazel_packages.get(platform.system())
  if package is None:
    return None

  if not os.access(package.file, os.X_OK):
    uri = BAZEL_BASE_URI + package.file
    sys.stdout.write("Downloading bazel from: {}\n".format(uri))

    def progress(block_count, block_size, total_size):
      """Report hook for urlretrieve: redraws a one-line progress bar."""
      if total_size <= 0:
        # The size is unknown; fall back to a fixed guess.
        # NOTE(review): 170**6 is about 2.4e13 bytes, so the bar stays at 0%
        # in this case; possibly `170 * 10**6` was intended -- confirm.
        total_size = 170**6
      # The last block is normally only partly filled, so block_count *
      # block_size can overshoot the real size; clamp to keep the bar within
      # its 40 columns and the percentage at or below 100.
      fraction = min(1.0, float(block_count * block_size) / total_size)
      num_chars = 40
      progress_chars = int(num_chars * fraction)
      sys.stdout.write("{} [{}{}] {}%\r".format(
          package.file, "#" * progress_chars,
          "." * (num_chars - progress_chars), int(fraction * 100.0)))
      # "\r" does not trigger a flush on line-buffered terminals.
      sys.stdout.flush()

    tmp_path, _ = urlretrieve(uri, None, progress)
    sys.stdout.write("\n")

    # Read the download back, and always delete the temporary file that
    # urlretrieve created so it does not accumulate across runs.
    try:
      with open(tmp_path, "rb") as downloaded_file:
        contents = downloaded_file.read()
    finally:
      os.remove(tmp_path)

    # Verify that the downloaded Bazel binary has the expected SHA256.
    digest = hashlib.sha256(contents).hexdigest()
    if digest != package.sha256:
      print(
          "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
          .format(package.sha256, digest))
      sys.exit(-1)

    # Write the file as the bazel file name.
    with open(package.file, "wb") as out_file:
      out_file.write(contents)

    # Mark the file as executable.
    st = os.stat(package.file)
    os.chmod(package.file,
             st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

  return "./" + package.file
def get_bazel_path(bazel_path_flag):
  """Returns the path to a Bazel binary, downloading Bazel if not found.

  Candidates are tried in order, stopping at the first one that is set: the
  --bazel_path flag value, a `bazel` found on the PATH, and finally a freshly
  downloaded and verified binary.

  Args:
    bazel_path_flag: value of the --bazel_path flag; may be None or empty.

  Returns:
    The path of the Bazel binary to use.

  Exits the process with status -1 when no Bazel could be found or downloaded.
  """
  candidate = (
      bazel_path_flag or which("bazel") or download_and_verify_bazel())
  if not candidate:
    print("Cannot find or download bazel. Please install bazel.")
    sys.exit(-1)
  return candidate
def check_bazel_version(bazel_path, min_version):
  """Checks Bazel's version is at least `min_version`.

  Runs `<bazel_path> version` and parses the "Build label:" line of its
  output.

  Args:
    bazel_path: path to the Bazel binary to query.
    min_version: minimum acceptable version as a dotted string of integers,
      e.g. "0.19.2".

  Returns:
    None. If no numeric build label is found, only a warning is printed.

  Exits the process with status -1 when the detected version is older than
  `min_version`.
  """
  version_output = shell([bazel_path, "--bazelrc=/dev/null", "version"])
  # Capture a well-formed dotted number only: no trailing character is
  # required (the label may end the output) and stray dots are never
  # captured, so every component converts cleanly with int().
  match = re.search(r"Build label: *([0-9]+(?:\.[0-9]+)*)", version_output)
  if match is None:
    print("Warning: bazel installation is not a release version. Make sure "
          "bazel is at least {}".format(min_version))
    return
  version = match.group(1)
  min_ints = [int(x) for x in min_version.split(".")]
  actual_ints = [int(x) for x in version.split(".")]
  # Pad with zeros so that e.g. "0.19" and "0.19.0" compare as equal.
  width = max(len(min_ints), len(actual_ints))
  min_ints += [0] * (width - len(min_ints))
  actual_ints += [0] * (width - len(actual_ints))
  if min_ints > actual_ints:
    print("Outdated bazel revision (>= {} required, found {})".format(
        min_version, version))
    # Nonzero status: an unusable Bazel must not look like a successful build.
    sys.exit(-1)
# Contents of the generated .bazelrc. The {placeholders} are filled in by
# write_bazelrc() from its keyword arguments.
BAZELRC_TEMPLATE = """
build --action_env PYTHON_BIN_PATH="{python_bin_path}"
build --python_path="{python_bin_path}"
build --action_env TF_NEED_CUDA="{tf_need_cuda}"
build --action_env CUDA_TOOLKIT_PATH="{cuda_toolkit_path}"
build --action_env CUDNN_INSTALL_PATH="{cudnn_install_path}"
build:opt --copt=-march=native
build:opt --copt=-Wno-sign-compare
build:opt --host_copt=-march=native
"""


def write_bazelrc(**kwargs):
  """Writes `.bazelrc` in the current directory from BAZELRC_TEMPLATE.

  Args:
    **kwargs: values for the template placeholders: python_bin_path,
      tf_need_cuda, cuda_toolkit_path and cudnn_install_path.

  Raises:
    KeyError: if a placeholder has no corresponding keyword argument. In that
      case any existing `.bazelrc` is left untouched.
  """
  # Render before opening the file: opening with "w" truncates it, so a
  # formatting failure must happen first to avoid destroying the old file.
  contents = BAZELRC_TEMPLATE.format(**kwargs)
  with open(".bazelrc", "w") as f:
    f.write(contents)
# ASCII-art "JAX" logo printed when the script starts. This is a raw string:
# the backslashes are literal, and the column alignment is significant.
BANNER = r"""
     _   _    __  __
    | | / \   \ \/ /
 _  | |/ _ \   \  /
| |_| / ___ \  /  \
 \___/_/   \_\/_/\_\
"""
# Usage text passed as the `epilog` of the argument parser built in main().
EPILOG = """
From the JAX repository root, run
python build.py
or
python3 build.py
Downloads and builds JAX's XLA dependency, installing XLA in the JAX source
tree.
"""
def _parse_string_as_bool(s):
"""Parses a string as a boolean argument."""
lower = s.lower()
if lower == "true":
return True
elif lower == "false":
return False
else:
raise ValueError("Expected either 'true' or 'false'; got {}".format(s))
def add_boolean_argument(parser, name, default=False, help_str=None):
  """Creates a boolean flag.

  Registers two mutually exclusive options that store into the same
  destination: `--<name>`, which may be given bare (meaning True) or with an
  explicit "true"/"false" value, and `--no<name>`, which stores False.

  Args:
    parser: the argparse parser to add the options to.
    name: the flag name, without leading dashes.
    default: value used when neither option is given.
    help_str: help text for the `--<name>` option.
  """
  positive_flag = "--{}".format(name)
  negative_flag = "--no{}".format(name)
  exclusive = parser.add_mutually_exclusive_group()
  exclusive.add_argument(
      positive_flag,
      nargs="?",
      default=default,
      const=True,
      type=_parse_string_as_bool,
      help=help_str)
  exclusive.add_argument(negative_flag, dest=name, action="store_false")
def main():
  """Parses flags, configures Bazel, then builds and installs XLA.

  Prints the banner and the chosen configuration, writes `.bazelrc` in the
  current directory, and runs the `//build:install_xla_in_source_tree` Bazel
  target, passing the current working directory as its argument.
  """
  flag_parser = argparse.ArgumentParser(
      description="Builds JAX from source.", epilog=EPILOG)
  flag_parser.add_argument(
      "--bazel_path",
      help="Path to the Bazel binary to use. The default is to find bazel via "
      "the PATH; if none is found, downloads a fresh copy of bazel from "
      "GitHub.")
  flag_parser.add_argument(
      "--python_bin_path",
      help="Path to Python binary to use. The default is the Python "
      "interpreter used to run the build script.")
  add_boolean_argument(
      flag_parser,
      "enable_cuda",
      help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
  flag_parser.add_argument(
      "--cuda_path", default="/usr/local/cuda", help="Path to the CUDA toolkit.")
  flag_parser.add_argument(
      "--cudnn_path", default="/usr/local/cuda", help="Path to CUDNN libraries.")
  flags = flag_parser.parse_args()

  print(BANNER)

  # Locate (or download) Bazel and make sure it is recent enough.
  bazel = get_bazel_path(flags.bazel_path)
  check_bazel_version(bazel, "0.19.2")
  print("Bazel binary path: {}".format(bazel))

  python_bin = get_python_bin_path(flags.python_bin_path)
  print("Python binary path: {}".format(python_bin))

  # Report the CUDA configuration; the paths are only shown when CUDA is on.
  with_cuda = bool(flags.enable_cuda)
  if with_cuda:
    cuda_status = "yes"
  else:
    cuda_status = "no"
  print("CUDA enabled: {}".format(cuda_status))
  if with_cuda:
    print("CUDA toolkit path: {}".format(flags.cuda_path))
    print("CUDNN library path: {}".format(flags.cudnn_path))

  write_bazelrc(
      python_bin_path=python_bin,
      tf_need_cuda=int(with_cuda),
      cuda_toolkit_path=flags.cuda_path,
      cudnn_install_path=flags.cudnn_path)

  print("\nBuilding XLA and installing it in the JAX source tree...")
  build_command = [
      bazel,
      "run",
      "-c",
      "opt",
      "//build:install_xla_in_source_tree",
      os.getcwd(),
  ]
  shell(build_command)


if __name__ == "__main__":
  main()

28
build/BUILD Normal file
View File

@ -0,0 +1,28 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# JAX is Autograd and XLA
licenses(["notice"]) # Apache 2
# Targets in this package are visible to every other package by default.
package(default_visibility = ["//visibility:public"])
# Runnable shell target wrapping install_xla_in_source_tree.sh. The XLA Python
# client is listed as `data` so that it is built and placed in the script's
# runfiles; the Bash runfiles library in `deps` is what the script sources to
# look those files up.
sh_binary(
name = "install_xla_in_source_tree",
srcs = ["install_xla_in_source_tree.sh"],
data = [
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
],
deps = ["@bazel_tools//tools/bash/runfiles"],
)

View File

@ -0,0 +1,67 @@
#!/bin/bash
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Script that installs JAX's XLA dependencies inside the JAX source tree.
# --- begin runfiles.bash initialization ---
# Copy-pasted from Bazel's Bash runfiles library (tools/bash/runfiles/runfiles.bash).
# Abort on command failure, on use of an unset variable, and on a failure
# anywhere inside a pipeline.
set -euo pipefail
# When neither RUNFILES_DIR nor RUNFILES_MANIFEST_FILE already names something
# that exists, look for a manifest or a runfiles tree next to this script ($0).
if [[ ! -d "${RUNFILES_DIR:-/dev/null}" && ! -f "${RUNFILES_MANIFEST_FILE:-/dev/null}" ]]; then
if [[ -f "$0.runfiles_manifest" ]]; then
export RUNFILES_MANIFEST_FILE="$0.runfiles_manifest"
elif [[ -f "$0.runfiles/MANIFEST" ]]; then
export RUNFILES_MANIFEST_FILE="$0.runfiles/MANIFEST"
elif [[ -f "$0.runfiles/bazel_tools/tools/bash/runfiles/runfiles.bash" ]]; then
export RUNFILES_DIR="$0.runfiles"
fi
fi
# Source runfiles.bash from the runfiles directory when it is there; otherwise
# take its path from the first matching manifest entry. Give up with an error
# on stderr when neither is available.
if [[ -f "${RUNFILES_DIR:-/dev/null}/bazel_tools/tools/bash/runfiles/runfiles.bash" ]]; then
source "${RUNFILES_DIR}/bazel_tools/tools/bash/runfiles/runfiles.bash"
elif [[ -f "${RUNFILES_MANIFEST_FILE:-/dev/null}" ]]; then
source "$(grep -m1 "^bazel_tools/tools/bash/runfiles/runfiles.bash " \
"$RUNFILES_MANIFEST_FILE" | cut -d ' ' -f 2-)"
else
echo >&2 "ERROR: cannot find @bazel_tools//tools/bash/runfiles:runfiles.bash"
exit 1
fi
# --- end runfiles.bash initialization ---
# Exactly one argument is expected: the JAX source tree to install into.
# Diagnostics are written to stderr, like the runfiles error above, so they
# remain visible when the caller captures this script's stdout.
if [[ $# -ne 1 ]]; then
  echo >&2 "Usage: $0 <target directory>"
  exit 1
fi
TARGET="$1"

# Sanity check: a JAX source tree contains a readable jax/lax.py.
if [[ ! -r "${TARGET}/jax/lax.py" ]]; then
  echo >&2 "Target directory ${TARGET} does not seem to be a JAX source tree" \
    "(missing jax/lax.py)"
  exit 1
fi
# Install the XLA artifacts into jax/lib. Three files are copied verbatim;
# xla_client.py is passed through sed so that its imports point to the new
# location.
XLA_RUNFILES="org_tensorflow/tensorflow/compiler/xla"
JAX_LIB_DIR="${TARGET}/jax/lib"

for artifact in xla_data_pb2.py python/pywrap_xla.py python/_pywrap_xla.so; do
  cp -f "$(rlocation "${XLA_RUNFILES}/${artifact}")" "${JAX_LIB_DIR}"
done

# Rewrite the two tensorflow.compiler.xla imports as package-relative imports
# and delete the hlo_pb2 import line.
sed \
  -e 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' \
  -e 's/from tensorflow.compiler.xla import xla_data_pb2/from . import xla_data_pb2/' \
  -e '/from tensorflow.compiler.xla.service import hlo_pb2/d' \
  < "$(rlocation "${XLA_RUNFILES}/python/xla_client.py")" \
  > "${JAX_LIB_DIR}/xla_client.py"

View File

@ -1,107 +0,0 @@
#!/bin/bash
# Exit on the first failing command (-e), and echo lines as they are read (-v)
# and commands as they are executed (-x).
set -exv
# For a build with CUDA, from the repo root run:
# bash build/build_jax.sh
# For building without CUDA (CPU-only), instead run:
# JAX_BUILD_WITH_CUDA=0 bash build/build_jax.sh
# To clean intermediate results, run
# rm -rf /tmp/jax-build/jax-bazel-output-user-root
# To clean everything, run
# rm -rf /tmp/jax-build
# CUDA is enabled unless the caller sets JAX_BUILD_WITH_CUDA=0.
JAX_BUILD_WITH_CUDA=${JAX_BUILD_WITH_CUDA:-1}
# Refuse to run unless the current directory has a .git directory and the last
# commit listed by `git rev-list` for HEAD matches init_commit.
init_commit=a30e858e59d7184b9e54dc3f3955238221d70439
if [[ ! -d .git || $(git rev-list --parents HEAD | tail -1) != ${init_commit} ]]
then
(>&2 echo "must be executed from jax repo root")
exit 1
fi
# Fixed scratch directory, reused by later runs.
tmp=/tmp/jax-build # could mktemp -d but this way we can cache results
mkdir -p ${tmp}
## get bazel
# Bazel is installed under the scratch directory only when its bin/ directory
# is not already there.
bazel_dir=${tmp}/jax-bazel
if [ ! -d ${bazel_dir}/bin ]
then
mkdir -p ${bazel_dir}
# Pick the installer for this OS; any other OS makes the script exit with 1.
case "$(uname -s)" in
Linux*) installer=bazel-0.19.2-installer-linux-x86_64.sh;;
Darwin*) installer=bazel-0.19.2-installer-darwin-x86_64.sh;;
*) exit 1;;
esac
# NOTE(review): the downloaded installer is executed without any checksum
# verification.
curl -OL https://github.com/bazelbuild/bazel/releases/download/0.19.2/${installer}
chmod +x ${installer}
bash ${installer} --prefix=${bazel_dir}
rm ${installer}
fi
# Put the locally installed bazel ahead of any other bazel on the PATH.
export PATH="${bazel_dir}/bin:$PATH"
## get and configure tensorflow for building xla
# Clone TensorFlow into ./tensorflow unless that directory already exists.
if [[ ! -d tensorflow ]]
then
git clone https://github.com/tensorflow/tensorflow.git
fi
pushd tensorflow
# The variables exported below are set before ./configure runs at the end of
# this section; presumably they pre-answer its prompts -- TODO confirm.
export PYTHON_BIN_PATH=${PYTHON_BIN_PATH:-$(which python)}
export PYTHON_LIB_PATH=${SP_DIR:-$(python -m site --user-site)}
export USE_DEFAULT_PYTHON_LIB_PATH=1
if [[ ${JAX_BUILD_WITH_CUDA} != 0 ]]
then
export CUDA_TOOLKIT_PATH=${CUDA_PATH:-/usr/local/cuda}
export CUDNN_INSTALL_PATH=${CUDA_TOOLKIT_PATH}
# Each version is taken from dot-separated fields 4-5 of the resolved path of
# the library symlink.
export TF_CUDA_VERSION=$(readlink -f ${CUDA_TOOLKIT_PATH}/lib64/libcudart.so | cut -d '.' -f4-5)
export TF_CUDNN_VERSION=$(readlink -f ${CUDNN_INSTALL_PATH}/lib64/libcudnn.so | cut -d '.' -f4-5)
export TF_CUDA_COMPUTE_CAPABILITIES="3.0,3.5,5.2,6.0,6.1,7.0"
export TF_NCCL_VERSION=2
export TF_NEED_CUDA=1
else
export TF_NEED_CUDA=0
fi
export GCC_HOST_COMPILER_PATH="/usr/bin/gcc"
export TF_ENABLE_XLA=1
export TF_NEED_MKL=0
export CC_OPT_FLAGS="-march=native -Wno-sign-compare"
export TF_NEED_IGNITE=1
export TF_NEED_OPENCL=0
export TF_NEED_OPENCL_SYCL=0
export TF_NEED_ROCM=0
export TF_NEED_MPI=0
export TF_DOWNLOAD_CLANG=0
export TF_SET_ANDROID_WORKSPACE=0
export TF_CUDA_CLANG=0
export TF_NEED_TENSORRT=0
./configure
popd
## build xla inside tensorflow
mkdir -p ${PYTHON_LIB_PATH}
# Bazel's output directories are kept under the scratch directory.
bazel_output_user_root=${tmp}/jax-bazel-output-user-root
bazel_output_base=${bazel_output_user_root}/output-base
bazel_opt="--output_user_root=${bazel_output_user_root} --output_base=${bazel_output_base} --bazelrc=tensorflow/tools/bazel.rc"
# Builds are always optimized; --config=cuda is added only for CUDA builds.
if [[ ${JAX_BUILD_WITH_CUDA} != 0 ]]
then
bazel_build_opt="-c opt --config=cuda"
else
bazel_build_opt="-c opt"
fi
bazel ${bazel_opt} build ${bazel_build_opt} jax:build_jax
## extract the pieces we need
# The runfiles location is hard-coded, including the "k8-opt" output directory
# name.
runfiles_prefix="execroot/__main__/bazel-out/k8-opt/bin/jax/build_jax.runfiles/org_tensorflow/tensorflow"
cp -f ${bazel_output_base}/${runfiles_prefix}/libtensorflow_framework.so jax/lib/
cp -f ${bazel_output_base}/${runfiles_prefix}/compiler/xla/xla_data_pb2.py jax/lib/
cp -f ${bazel_output_base}/${runfiles_prefix}/compiler/xla/python/{xla_client.py,pywrap_xla.py,_pywrap_xla.so} jax/lib/
## rewrite some imports
# Point the copied xla_client.py at the sibling modules in jax/lib and drop
# its hlo_pb2 import line.
sed -i 's/from tensorflow.compiler.xla.python import pywrap_xla as c_api/from . import pywrap_xla as c_api/' jax/lib/xla_client.py
sed -i 's/from tensorflow.compiler.xla import xla_data_pb2/from . import xla_data_pb2/' jax/lib/xla_client.py
sed -i '/from tensorflow.compiler.xla.service import hlo_pb2/d' jax/lib/xla_client.py
## clean up
# Removes the bazel-* symlinks, the TensorFlow checkout and the Bazel output
# tree; the scratch directory itself is kept.
rm -f bazel-* # symlinks
rm -rf tensorflow
rm -rf ${bazel_output_user_root} # clean build results
# rm -rf ${tmp} # clean everything, including the bazel binary