2018-11-26 12:37:24 -08:00
|
|
|
#!/usr/bin/python
|
|
|
|
#
|
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
2018-12-06 21:35:03 -05:00
|
|
|
# Helper script for building JAX's jaxlib easily.
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import collections
|
|
|
|
import hashlib
|
|
|
|
import os
|
|
|
|
import platform
|
|
|
|
import re
|
|
|
|
import shutil
|
|
|
|
import stat
|
|
|
|
import subprocess
|
|
|
|
import sys
|
|
|
|
import urllib
|
|
|
|
|
|
|
|
# pylint: disable=g-import-not-at-top
|
|
|
|
if hasattr(urllib, "urlretrieve"):
|
|
|
|
urlretrieve = urllib.urlretrieve
|
|
|
|
else:
|
|
|
|
import urllib.request
|
|
|
|
urlretrieve = urllib.request.urlretrieve
|
|
|
|
|
|
|
|
if hasattr(shutil, "which"):
|
|
|
|
which = shutil.which
|
|
|
|
else:
|
|
|
|
from distutils.spawn import find_executable as which
|
|
|
|
# pylint: enable=g-import-not-at-top
|
|
|
|
|
|
|
|
|
|
|
|
def shell(cmd):
  """Runs `cmd` and returns its standard output.

  Args:
    cmd: the command as a list of arguments (no shell interpretation).

  Returns:
    The command's stdout, decoded as UTF-8, with surrounding whitespace
    stripped.

  Raises:
    subprocess.CalledProcessError: if the command exits with a non-zero
      status. Whatever stdout was captured is echoed first, because
      check_output would otherwise discard it and the user could not see
      why the command failed.
  """
  try:
    output = subprocess.check_output(cmd)
  except subprocess.CalledProcessError as e:
    if e.output:
      sys.stdout.write(e.output.decode("UTF-8", "replace"))
      sys.stdout.flush()
    raise
  return output.decode("UTF-8").strip()
|
|
|
|
|
|
|
|
|
|
|
|
# Python
|
|
|
|
|
|
|
|
def get_python_bin_path(python_bin_path_flag):
  """Chooses the Python interpreter the build should use.

  Args:
    python_bin_path_flag: value of --python_bin_path; may be None or empty.

  Returns:
    `python_bin_path_flag` when it is non-empty; otherwise the interpreter
    that is running this script (`sys.executable`).
  """
  if python_bin_path_flag:
    return python_bin_path_flag
  return sys.executable
|
|
|
|
|
|
|
|
|
|
|
|
# Bazel
|
|
|
|
|
2019-03-31 10:31:47 -07:00
|
|
|
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/0.24.0/"

# A downloadable Bazel release binary: its file name (appended to
# BAZEL_BASE_URI) and the expected SHA256 hex digest of its contents.
BazelPackage = collections.namedtuple("BazelPackage", "file sha256")

# Release binaries keyed by the value of `platform.system()`.
bazel_packages = dict(
    Linux=BazelPackage(
        file="bazel-0.24.0-linux-x86_64",
        sha256=
        "cf78da6f1b65e9e35f485eab421756c4b5188a705695276843759f3c3586bb0c"),
    Darwin=BazelPackage(
        file="bazel-0.24.0-darwin-x86_64",
        sha256=
        "adaacec710cae5a217dd967766fe489b8034aa9c0cb44d4eb06813d224489e01"),
)
|
|
|
|
|
|
|
|
|
|
|
|
def download_and_verify_bazel():
  """Downloads a bazel binary from GitHub, verifying its SHA256 hash.

  The package for the current OS (`platform.system()`) is looked up in
  `bazel_packages`. It is saved in the current directory under the package's
  file name. If an executable file with that name already exists, it is
  reused as-is: it is not downloaded again and not re-verified.

  Returns:
    "./<file name>" of the bazel binary, or None if `bazel_packages` has no
    entry for the current OS.

  Exits the process with status -1 if the downloaded file's SHA256 digest
  does not match the expected one.
  """
  package = bazel_packages.get(platform.system())
  if package is None:
    return None

  if not os.access(package.file, os.X_OK):
    uri = BAZEL_BASE_URI + package.file
    sys.stdout.write("Downloading bazel from: {}\n".format(uri))

    def progress(block_count, block_size, total_size):
      """urlretrieve reporthook that draws a 40-character progress bar."""
      if total_size <= 0:
        # Size not reported: fall back to a rough estimate of the binary's
        # size (~170 MB) so the bar still advances.
        total_size = 170 * 10**6
      # The last block can overshoot total_size; clamp so the bar never grows
      # past its width and the percentage never exceeds 100.
      fraction = min(1.0, float(block_count * block_size) / total_size)
      num_chars = 40
      progress_chars = int(num_chars * fraction)
      sys.stdout.write("{} [{}{}] {}%\r".format(
          package.file, "#" * progress_chars,
          "." * (num_chars - progress_chars), int(fraction * 100.0)))
      # A bare "\r" does not trigger a line-buffered flush; flush explicitly
      # so the bar is visible while the download is in progress.
      sys.stdout.flush()

    tmp_path, _ = urlretrieve(uri, None, progress)
    sys.stdout.write("\n")

    with open(tmp_path, "rb") as downloaded_file:
      contents = downloaded_file.read()
    # With no target filename, urlretrieve saves a remote URL into a fresh
    # temporary file (assumes BAZEL_BASE_URI stays a remote URL). Delete it
    # now that the contents are in memory; cleanup is best-effort.
    try:
      os.remove(tmp_path)
    except OSError:
      pass

    # Verify that the downloaded Bazel binary has the expected SHA256.
    digest = hashlib.sha256(contents).hexdigest()
    if digest != package.sha256:
      print(
          "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
          .format(package.sha256, digest))
      sys.exit(-1)

    # Write the file as the bazel file name.
    with open(package.file, "wb") as out_file:
      out_file.write(contents)

    # Mark the file as executable.
    st = os.stat(package.file)
    os.chmod(package.file,
             st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

  return "./" + package.file
|
|
|
|
|
|
|
|
|
|
|
|
def get_bazel_path(bazel_path_flag):
  """Resolves which Bazel binary to use.

  Candidates are tried in order: the explicit --bazel_path value, then
  `bazel` found on the PATH, then a downloaded release binary (see
  download_and_verify_bazel).

  Args:
    bazel_path_flag: value of --bazel_path; may be None or empty.

  Returns:
    A path to a Bazel binary. If none can be found or downloaded, prints an
    error and exits the process with status -1 instead.
  """
  if bazel_path_flag:
    return bazel_path_flag
  # `or` short-circuits, so the download only happens when PATH has no bazel.
  bazel = which("bazel") or download_and_verify_bazel()
  if not bazel:
    print("Cannot find or download bazel. Please install bazel.")
    sys.exit(-1)
  return bazel
|
|
|
|
|
|
|
|
|
|
|
|
def check_bazel_version(bazel_path, min_version):
  """Checks Bazel's version is at least `min_version`.

  Runs `<bazel_path> --bazelrc=/dev/null version` and parses the numeric
  version from its "Build label:" line. If no numeric label is found (for
  example a non-release build), only a warning is printed. If the version is
  older than `min_version`, an error is printed and the process exits with a
  non-zero status.

  Args:
    bazel_path: path to the bazel binary.
    min_version: minimum acceptable version as a dotted string, e.g. "0.24.0".
  """
  version_output = shell([bazel_path, "--bazelrc=/dev/null", "version"])
  # Capture a dotted numeric version that is not immediately followed by
  # another digit or dot. Unlike requiring a trailing non-version character,
  # the lookahead also matches a label at the very end of the (stripped)
  # output, and it rejects malformed labels such as "0.24." instead of
  # crashing on int("").
  match = re.search(r"Build label: *([0-9]+(?:\.[0-9]+)*)(?![0-9.])",
                    version_output)
  if match is None:
    print("Warning: bazel installation is not a release version. Make sure "
          "bazel is at least {}".format(min_version))
    return
  version = match.group(1)
  min_ints = [int(x) for x in min_version.split(".")]
  actual_ints = [int(x) for x in version.split(".")]
  # Pad with zeros so that e.g. "0.24" and "0.24.0" compare as equal.
  width = max(len(min_ints), len(actual_ints))
  min_ints += [0] * (width - len(min_ints))
  actual_ints += [0] * (width - len(actual_ints))
  if min_ints > actual_ints:
    print("Outdated bazel revision (>= {} required, found {})".format(
        min_version, version))
    # Exit with a failure status so callers and CI see the error.
    sys.exit(-1)
|
|
|
|
|
|
|
|
|
|
|
|
BAZELRC_TEMPLATE = """
|
|
|
|
build --action_env PYTHON_BIN_PATH="{python_bin_path}"
|
|
|
|
build --python_path="{python_bin_path}"
|
|
|
|
build --action_env TF_NEED_CUDA="{tf_need_cuda}"
|
|
|
|
build --action_env CUDA_TOOLKIT_PATH="{cuda_toolkit_path}"
|
|
|
|
build --action_env CUDNN_INSTALL_PATH="{cudnn_install_path}"
|
2018-12-10 09:18:49 -05:00
|
|
|
build --distinct_host_configuration=false
|
2018-12-14 12:21:25 -05:00
|
|
|
build --copt=-Wno-sign-compare
|
|
|
|
build -c opt
|
2018-11-26 12:37:24 -08:00
|
|
|
build:opt --copt=-march=native
|
|
|
|
build:opt --host_copt=-march=native
|
2018-12-14 12:21:25 -05:00
|
|
|
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
|
2019-01-07 10:59:08 -05:00
|
|
|
|
2019-03-26 22:26:12 -04:00
|
|
|
# Sets the default Apple platform to macOS.
|
|
|
|
build --apple_platform_type=macos
|
|
|
|
|
2019-01-07 10:59:08 -05:00
|
|
|
# Disable enabled-by-default TensorFlow features that we don't care about.
|
|
|
|
build --define=no_aws_support=true
|
|
|
|
build --define=no_gcp_support=true
|
|
|
|
build --define=no_hdfs_support=true
|
|
|
|
build --define=no_kafka_support=true
|
|
|
|
build --define=no_ignite_support=true
|
2019-03-31 10:31:47 -07:00
|
|
|
build --define=grpc_no_ares=true
|
2019-03-06 15:20:43 -08:00
|
|
|
|
|
|
|
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
|
|
|
|
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
|
2018-11-26 12:37:24 -08:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def write_bazelrc(**kwargs):
  """Renders BAZELRC_TEMPLATE and writes it to ../.bazelrc.

  The path is relative to the current working directory.

  Args:
    **kwargs: values for the template's placeholders (python_bin_path,
      tf_need_cuda, cuda_toolkit_path, cudnn_install_path).

  Raises:
    KeyError: if a placeholder has no value in `kwargs`. In that case any
      existing ../.bazelrc is left untouched.
  """
  # Render before opening: mode "w" truncates the file, so a formatting
  # error must not leave an empty .bazelrc behind.
  contents = BAZELRC_TEMPLATE.format(**kwargs)
  with open("../.bazelrc", "w") as f:
    f.write(contents)
|
|
|
|
|
|
|
|
|
|
|
|
BANNER = r"""
|
2018-12-05 08:22:27 -08:00
|
|
|
_ _ __ __
|
|
|
|
| | / \ \ \/ /
|
|
|
|
_ | |/ _ \ \ /
|
|
|
|
| |_| / ___ \/ \
|
|
|
|
\___/_/ \/_/\_\
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
EPILOG = """
|
|
|
|
|
2018-12-06 21:35:03 -05:00
|
|
|
From the 'build' directory in the JAX repository, run
|
|
|
|
python build.py
|
2018-11-26 12:37:24 -08:00
|
|
|
or
|
2018-12-06 21:35:03 -05:00
|
|
|
python3 build.py
|
|
|
|
to download and build JAX's XLA (jaxlib) dependency.
|
2018-11-26 12:37:24 -08:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_string_as_bool(s):
  """Converts the text "true" or "false" (in any letter case) to a bool.

  Used as the argparse `type=` converter for boolean flags.

  Raises:
    ValueError: if `s` is neither "true" nor "false", ignoring case.
  """
  choices = {"true": True, "false": False}
  key = s.lower()
  if key not in choices:
    raise ValueError("Expected either 'true' or 'false'; got {}".format(s))
  return choices[key]
|
|
|
|
|
|
|
|
|
|
|
|
def add_boolean_argument(parser, name, default=False, help_str=None):
  """Registers a mutually exclusive pair of flags for boolean option `name`.

  `--<name>` on its own sets the option to True, and `--<name>=true|false`
  (any letter case) sets it explicitly. `--no<name>` sets it to False. Both
  flags store into the same destination, `name`.

  Args:
    parser: the argparse.ArgumentParser to add the flags to.
    name: option name without leading dashes, e.g. "enable_cuda".
    default: default value for the option.
    help_str: help text for the positive `--<name>` flag.
  """
  flags = parser.add_mutually_exclusive_group()
  positive_kwargs = dict(
      nargs="?",
      const=True,
      default=default,
      type=_parse_string_as_bool,
      help=help_str,
  )
  flags.add_argument("--" + name, **positive_kwargs)
  flags.add_argument("--no" + name, action="store_false", dest=name)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
  """Builds jaxlib's XLA dependency from source.

  Parses command-line flags, finds or downloads Bazel, checks its version,
  writes ../.bazelrc for the chosen configuration, and then runs the
  `:install_xla_in_source_tree` Bazel target with the current directory as
  its argument.
  """
  parser = argparse.ArgumentParser(
      description="Builds jaxlib from source.", epilog=EPILOG)
  parser.add_argument(
      "--bazel_path",
      help="Path to the Bazel binary to use. The default is to find bazel via "
      "the PATH; if none is found, downloads a fresh copy of bazel from "
      "GitHub.")
  parser.add_argument(
      "--python_bin_path",
      help="Path to Python binary to use. The default is the Python "
      "interpreter used to run the build script.")
  add_boolean_argument(
      parser,
      "enable_march_native",
      default=False,
      help_str="Generate code targeted to the current machine? This may "
      "increase performance, but may generate code that does not run on "
      "older machines.")
  add_boolean_argument(
      parser,
      "enable_mkl_dnn",
      default=True,
      help_str="Should we build with MKL-DNN enabled?")
  add_boolean_argument(
      parser,
      "enable_cuda",
      help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
  parser.add_argument(
      "--cuda_path",
      default="/usr/local/cuda",
      help="Path to the CUDA toolkit.")
  parser.add_argument(
      "--cudnn_path",
      default="/usr/local/cuda",
      help="Path to CUDNN libraries.")
  args = parser.parse_args()

  print(BANNER)
  # Work from the directory containing this script, so that relative paths
  # (../.bazelrc, a downloaded bazel binary, the Bazel target) resolve the
  # same way regardless of where the script was invoked from.
  os.chdir(os.path.dirname(os.path.abspath(__file__)))

  # Find a working Bazel.
  bazel_path = get_bazel_path(args.bazel_path)
  check_bazel_version(bazel_path, "0.24.0")
  print("Bazel binary path: {}".format(bazel_path))

  python_bin_path = get_python_bin_path(args.python_bin_path)
  print("Python binary path: {}".format(python_bin_path))

  print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
  print("-march=native: {}".format("yes" if args.enable_march_native else "no"))

  cuda_toolkit_path = args.cuda_path
  cudnn_install_path = args.cudnn_path
  print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no"))
  if args.enable_cuda:
    print("CUDA toolkit path: {}".format(cuda_toolkit_path))
    print("CUDNN library path: {}".format(cudnn_install_path))
  write_bazelrc(
      python_bin_path=python_bin_path,
      tf_need_cuda=1 if args.enable_cuda else 0,
      cuda_toolkit_path=cuda_toolkit_path,
      cudnn_install_path=cudnn_install_path)

  print("\nBuilding XLA and installing it in the jaxlib source tree...")
  # Each --config=<name> activates the "build:<name>" lines of the generated
  # .bazelrc.
  config_args = []
  if args.enable_march_native:
    config_args += ["--config=opt"]
  if args.enable_mkl_dnn:
    config_args += ["--config=mkl_open_source_only"]
  if args.enable_cuda:
    config_args += ["--config=cuda"]
  shell(
      [bazel_path, "run", "--verbose_failures=true"] +
      config_args +
      [":install_xla_in_source_tree", os.getcwd()])


if __name__ == "__main__":
  main()
|