2018-11-26 12:37:24 -08:00
|
|
|
#!/usr/bin/python
|
|
|
|
#
|
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
2018-12-06 21:35:03 -05:00
|
|
|
# Helper script for building JAX's jaxlib dependency from source.
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import collections
|
|
|
|
import hashlib
|
|
|
|
import os
|
|
|
|
import platform
|
|
|
|
import re
|
|
|
|
import shutil
|
|
|
|
import stat
|
|
|
|
import subprocess
|
|
|
|
import sys
|
|
|
|
import urllib
|
|
|
|
|
|
|
|
# pylint: disable=g-import-not-at-top
|
|
|
|
if hasattr(urllib, "urlretrieve"):
|
|
|
|
urlretrieve = urllib.urlretrieve
|
|
|
|
else:
|
|
|
|
import urllib.request
|
|
|
|
urlretrieve = urllib.request.urlretrieve
|
|
|
|
|
|
|
|
if hasattr(shutil, "which"):
|
|
|
|
which = shutil.which
|
|
|
|
else:
|
|
|
|
from distutils.spawn import find_executable as which
|
|
|
|
# pylint: enable=g-import-not-at-top
|
|
|
|
|
|
|
|
|
2020-11-10 00:23:54 +08:00
|
|
|
def is_windows():
  """Reports whether this interpreter is running on a Windows host.

  Returns:
    True when ``sys.platform`` begins with ``"win32"``, otherwise False.
  """
  platform_name = sys.platform
  return platform_name.startswith("win32")
|
|
|
|
|
|
|
|
|
2018-11-26 12:37:24 -08:00
|
|
|
def shell(cmd):
  """Runs a command and returns its standard output as stripped text.

  Args:
    cmd: the program and its arguments as a list; no shell is involved.

  Returns:
    The command's stdout decoded as UTF-8, with surrounding whitespace
    removed.

  Raises:
    subprocess.CalledProcessError: if the command exits with a non-zero
      status. Whatever it wrote to stdout is printed first so the failure
      can be diagnosed.
  """
  try:
    output = subprocess.check_output(cmd)
  except subprocess.CalledProcessError as e:
    # e.output is raw bytes; decode it so multi-line output prints as text
    # rather than as a single escaped bytes literal.
    if e.output:
      print(e.output.decode("UTF-8", errors="replace"))
    raise
  return output.decode("UTF-8").strip()
|
|
|
|
|
|
|
|
|
|
|
|
# Python
|
|
|
|
|
|
|
|
def get_python_bin_path(python_bin_path_flag):
  """Chooses the Python interpreter the build should target.

  Args:
    python_bin_path_flag: value of --python_bin_path; may be None or empty.

  Returns:
    The flag value when it is non-empty, otherwise ``sys.executable``, with
    every ``os.sep`` rewritten as a forward slash.
  """
  if python_bin_path_flag:
    chosen = python_bin_path_flag
  else:
    chosen = sys.executable
  # os.sep is a single character, so split/join equals a plain replace.
  return "/".join(chosen.split(os.sep))
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
|
2020-02-17 11:24:03 -08:00
|
|
|
def get_python_version(python_bin_path):
  """Asks an interpreter for its major and minor version numbers.

  Args:
    python_bin_path: path of the interpreter to query.

  Returns:
    A ``(major, minor)`` tuple of ints.

  Raises:
    ValueError: if the interpreter's reply is not of the form "X.Y".
  """
  probe_cmd = [python_bin_path, "-c",
               "import sys; print(\"{}.{}\".format(sys.version_info[0], "
               "sys.version_info[1]))"]
  major_text, minor_text = shell(probe_cmd).split(".")
  return int(major_text), int(minor_text)
|
|
|
|
|
|
|
|
def check_python_version(python_version):
  """Exits the build unless the target Python is 3.7 or newer.

  Args:
    python_version: a ``(major, minor)`` tuple such as the one returned by
      get_python_version.
  """
  minimum_supported = (3, 7)
  if python_version >= minimum_supported:
    return
  print("ERROR: JAX requires Python 3.7 or newer, found ", python_version)
  sys.exit(-1)
|
|
|
|
|
|
|
|
|
2021-05-24 11:18:37 -04:00
|
|
|
def check_numpy_version(python_bin_path):
  """Ensures the target interpreter has NumPy 1.18 or newer.

  Args:
    python_bin_path: interpreter to query.

  Returns:
    The NumPy version string reported by that interpreter. Exits the
    process with status -1 if its major.minor version is below 1.18.
  """
  probe = "import numpy as np; print(np.__version__)"
  version = shell([python_bin_path, "-c", probe])
  leading_fields = version.split('.')[:2]
  numpy_version = tuple(int(field) for field in leading_fields)
  too_old = numpy_version < (1, 18)
  if too_old:
    print("ERROR: JAX requires NumPy 1.18 or newer, found " + version + ".")
    sys.exit(-1)
  return version
|
|
|
|
|
|
|
|
def check_scipy_version(python_bin_path):
  """Ensures the target interpreter has SciPy 1.0 or newer.

  Args:
    python_bin_path: interpreter to query.

  Returns:
    The SciPy version string reported by that interpreter. Exits the
    process with status -1 if its major.minor version is below 1.0.
  """
  probe = "import scipy as sp; print(sp.__version__)"
  version = shell([python_bin_path, "-c", probe])
  leading_fields = version.split('.')[:2]
  scipy_version = tuple(int(field) for field in leading_fields)
  too_old = scipy_version < (1, 0)
  if too_old:
    print("ERROR: JAX requires SciPy 1.0 or newer, found " + version + ".")
    sys.exit(-1)
  return version
|
|
|
|
|
2018-11-26 12:37:24 -08:00
|
|
|
# Bazel
|
|
|
|
|
2021-07-23 10:39:02 -04:00
|
|
|
# Bazel release that is downloaded when no suitable bazel is found locally.
# The download URI and every binary file name below are derived from it, so
# a Bazel upgrade only needs this constant and the SHA256 digests updated.
BAZEL_VERSION = "4.1.0"

BAZEL_BASE_URI = (
    f"https://github.com/bazelbuild/bazel/releases/download/{BAZEL_VERSION}/")

# One downloadable Bazel binary. When base_uri is None the download location
# falls back to BAZEL_BASE_URI; file is appended to it, and sha256 is the
# expected hex digest of the downloaded binary.
BazelPackage = collections.namedtuple("BazelPackage",
                                      ["base_uri", "file", "sha256"])

# Keyed by (platform.system(), platform.machine()).
bazel_packages = {
    ("Linux", "x86_64"):
        BazelPackage(
            base_uri=None,
            file=f"bazel-{BAZEL_VERSION}-linux-x86_64",
            sha256=
            "0eb2e378d2782e7810753e2162245ad1179c1bb12f848c692b4a595b4edf779b"),
    ("Linux", "aarch64"):
        BazelPackage(
            base_uri=None,
            file=f"bazel-{BAZEL_VERSION}-linux-arm64",
            sha256=
            "b3834742166379e52b880319dec4699082cb26fa96cbb783087deedc5fbb5f2b"),
    ("Darwin", "x86_64"):
        BazelPackage(
            base_uri=None,
            file=f"bazel-{BAZEL_VERSION}-darwin-x86_64",
            sha256=
            "2eecc3abb0ff653ed0bffdb9fbfda7b08548c2868f13da4a995f01528db200a9"),
    ("Darwin", "arm64"):
        BazelPackage(
            base_uri=None,
            file=f"bazel-{BAZEL_VERSION}-darwin-arm64",
            sha256=
            "c372d39ab9dac96f7fdfc2dd649e88b05ee4c94ce3d6cf2313438ef0ca6d5ac1"),
    ("Windows", "x86_64"):
        BazelPackage(
            base_uri=None,
            file=f"bazel-{BAZEL_VERSION}-windows-x86_64.exe",
            sha256=
            "7b2077af7055b421fe31822f83c3c3c15e36ff39b69560ba2472dde92dd45b46"),
}
# platform.machine() reports 64-bit x86 as "AMD64" on Windows, so register
# the same package under that spelling too.
bazel_packages[("Windows", "AMD64")] = bazel_packages[("Windows", "x86_64")]
|
|
|
|
|
|
|
|
|
2021-07-23 10:39:02 -04:00
|
|
|
def download_and_verify_bazel():
  """Downloads a bazel binary from GitHub, verifying its SHA256 hash.

  The package for the current (OS, CPU) pair is looked up in bazel_packages.
  If a file with the package's name already exists in the working directory
  and is executable, it is reused as-is (no download and no re-verification).
  Otherwise the binary is downloaded, checked against the expected SHA256,
  written into the working directory under the package's file name, and
  marked executable.

  Returns:
    A relative path ("./<file>") to the bazel binary, or None if there is
    no prebuilt package for this platform. Exits the process with status -1
    if the downloaded binary's checksum does not match.
  """
  system = platform.system()
  machine = platform.machine()
  package = bazel_packages.get((system, machine))
  if package is None and machine == "AMD64":
    # platform.machine() reports 64-bit x86 as "AMD64" on Windows, while the
    # package table may only list it as "x86_64".
    package = bazel_packages.get((system, "x86_64"))
  if package is None:
    return None

  if not os.access(package.file, os.X_OK):
    uri = (package.base_uri or BAZEL_BASE_URI) + package.file
    sys.stdout.write("Downloading bazel from: {}\n".format(uri))

    def progress(block_count, block_size, total_size):
      if total_size <= 0:
        # urlretrieve reports -1 when the size is unknown; use a rough guess
        # (~170 MB) so the bar still advances.
        total_size = 170 * 10**6
      # The final block can overshoot total_size; clamp so the bar never
      # grows past its width or reports more than 100%.
      fraction = min(1.0, (block_count * block_size) / total_size)
      num_chars = 40
      progress_chars = int(num_chars * fraction)
      sys.stdout.write("{} [{}{}] {}%\r".format(
          package.file, "#" * progress_chars,
          "." * (num_chars - progress_chars), int(fraction * 100.0)))

    tmp_path, _ = urlretrieve(uri, None,
                              progress if sys.stdout.isatty() else None)
    sys.stdout.write("\n")

    try:
      with open(tmp_path, "rb") as downloaded_file:
        contents = downloaded_file.read()
    finally:
      # The temporary download is no longer needed once read; removing it is
      # best-effort, so a failure here must not abort the build.
      try:
        os.remove(tmp_path)
      except OSError:
        pass

    # Verify that the downloaded Bazel binary has the expected SHA256.
    digest = hashlib.sha256(contents).hexdigest()
    if digest != package.sha256:
      print(
          "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
          .format(package.sha256, digest))
      sys.exit(-1)

    # Write the file as the bazel file name.
    with open(package.file, "wb") as out_file:
      out_file.write(contents)

    # Mark the file as executable.
    st = os.stat(package.file)
    os.chmod(package.file,
             st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

  return os.path.join(".", package.file)
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
|
2021-07-23 10:39:02 -04:00
|
|
|
def get_bazel_paths(bazel_path_flag):
  """Lazily yields candidate bazel locations, most preferred first.

  The candidates are, in order: the --bazel_path flag value, the ``bazel``
  found on PATH, and a freshly downloaded binary. Any of them may be None.
  Each candidate is computed only when the caller asks for it, so iterating
  far enough can have side effects (e.g. triggering the download).
  """
  candidate_thunks = (
      lambda: bazel_path_flag,
      lambda: which("bazel"),
      lambda: download_and_verify_bazel(),
  )
  for thunk in candidate_thunks:
    yield thunk()
|
2021-02-01 14:00:44 +03:00
|
|
|
|
2018-11-26 12:37:24 -08:00
|
|
|
|
2021-07-23 10:39:02 -04:00
|
|
|
def get_bazel_path(bazel_path_flag):
  """Returns the path to a usable Bazel binary, downloading Bazel if needed.

  Candidates from get_bazel_paths() are tried in order. Empty ones are
  skipped, and the first that passes check_bazel_version (Bazel 2.0.0 or
  newer) is returned. If no candidate qualifies, an error is printed and
  the process exits with status -1.

  NOTE: the manual version check only really matters for bazel < 2.0.0.
  Newer bazel releases check themselves against .bazelversion (for details
  see https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).
  """
  usable = (candidate for candidate in get_bazel_paths(bazel_path_flag)
            if candidate and check_bazel_version(candidate))
  bazel_path = next(usable, None)
  if bazel_path is not None:
    return bazel_path

  print("Cannot find or download bazel. Please install bazel.")
  sys.exit(-1)
|
|
|
|
|
|
|
|
|
2021-02-13 16:32:21 +03:00
|
|
|
def check_bazel_version(bazel_path):
  """Reports whether ``bazel_path`` is a runnable bazel, version 2.0.0+.

  Args:
    bazel_path: path of a candidate bazel binary.

  Returns:
    True if ``bazel version`` succeeds and its "Build label" is at least
    2.0.0. False if the binary cannot be started, exits with an error, or
    reports no parseable version, so callers can move on to the next
    candidate.
  """
  try:
    # Presumably --bazelrc=/dev/null keeps rc files from influencing this
    # query; Bazel treats /dev/null specially for this flag.
    version_output = shell([bazel_path, "--bazelrc=/dev/null", "version"])
  except (subprocess.CalledProcessError, OSError):
    # OSError covers a path that does not exist or is not executable; treat
    # it like any other unusable candidate instead of crashing.
    return False
  # The version may be followed by a suffix (e.g. "rc1"), other text, or the
  # end of the (already stripped) output.
  match = re.search(r"Build label: *([0-9.]+)(?:[^0-9.]|$)", version_output)
  if match is None:
    return False
  try:
    actual_ints = [int(x) for x in match.group(1).split(".")]
  except ValueError:
    # Labels such as "." or "2..0" are not comparable versions.
    return False
  return actual_ints >= [2, 0, 0]
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
|
|
|
|
BAZELRC_TEMPLATE = """
|
2020-02-22 09:45:24 -08:00
|
|
|
# Flag to enable remote config
|
|
|
|
common --experimental_repo_remote_exec
|
|
|
|
|
2019-11-01 10:41:51 -04:00
|
|
|
build --repo_env PYTHON_BIN_PATH="{python_bin_path}"
|
2020-11-09 16:25:20 -05:00
|
|
|
build --action_env=PYENV_ROOT
|
2018-11-26 12:37:24 -08:00
|
|
|
build --python_path="{python_bin_path}"
|
2019-11-01 10:41:51 -04:00
|
|
|
build --repo_env TF_NEED_CUDA="{tf_need_cuda}"
|
2020-04-15 10:57:53 -04:00
|
|
|
build --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"
|
2020-12-05 00:07:04 +01:00
|
|
|
build --repo_env TF_NEED_ROCM="{tf_need_rocm}"
|
|
|
|
build --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"
|
2021-02-16 14:12:01 -05:00
|
|
|
build:posix --copt=-Wno-sign-compare
|
2018-12-14 12:21:25 -05:00
|
|
|
build -c opt
|
2021-02-16 14:12:01 -05:00
|
|
|
build:avx_posix --copt=-mavx
|
|
|
|
build:avx_posix --host_copt=-mavx
|
|
|
|
build:avx_windows --copt=/arch=AVX
|
|
|
|
build:native_arch_posix --copt=-march=native
|
|
|
|
build:native_arch_posix --host_copt=-march=native
|
2018-12-14 12:21:25 -05:00
|
|
|
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
|
2019-01-07 10:59:08 -05:00
|
|
|
|
2019-03-26 22:26:12 -04:00
|
|
|
# Sets the default Apple platform to macOS.
|
|
|
|
build --apple_platform_type=macos
|
2019-12-03 11:59:31 -05:00
|
|
|
build --macos_minimum_os=10.9
|
2019-03-26 22:26:12 -04:00
|
|
|
|
2019-11-24 13:06:23 -05:00
|
|
|
# Make Bazel print out all options from rc files.
|
|
|
|
build --announce_rc
|
|
|
|
|
2020-06-25 14:37:14 -04:00
|
|
|
build --define open_source_build=true
|
|
|
|
|
2019-01-07 10:59:08 -05:00
|
|
|
# Disable enabled-by-default TensorFlow features that we don't care about.
|
2021-02-16 14:12:01 -05:00
|
|
|
build:posix --define=no_aws_support=true
|
|
|
|
build:posix --define=no_gcp_support=true
|
|
|
|
build:posix --define=no_hdfs_support=true
|
2019-01-07 10:59:08 -05:00
|
|
|
build --define=no_kafka_support=true
|
|
|
|
build --define=no_ignite_support=true
|
2019-03-31 10:31:47 -07:00
|
|
|
build --define=grpc_no_ares=true
|
2019-03-06 15:20:43 -08:00
|
|
|
|
|
|
|
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
|
2021-03-03 08:55:08 +01:00
|
|
|
build:cuda --@local_config_cuda//:enable_cuda
|
2019-08-01 21:43:03 -04:00
|
|
|
|
2020-12-05 00:07:04 +01:00
|
|
|
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
|
|
|
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
|
|
|
|
build:nonccl --define=no_nccl_support=true
|
|
|
|
|
2019-08-01 21:43:03 -04:00
|
|
|
build --spawn_strategy=standalone
|
|
|
|
build --strategy=Genrule=standalone
|
2019-09-28 15:05:41 -04:00
|
|
|
|
2020-11-10 00:23:54 +08:00
|
|
|
build --enable_platform_specific_config
|
|
|
|
|
|
|
|
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
|
|
|
|
# _USE_MATH_DEFINES is defined.
|
|
|
|
build:windows --copt=/D_USE_MATH_DEFINES
|
|
|
|
build:windows --host_copt=/D_USE_MATH_DEFINES
|
|
|
|
|
|
|
|
# Make sure to include as little of windows.h as possible
|
|
|
|
build:windows --copt=-DWIN32_LEAN_AND_MEAN
|
|
|
|
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
|
|
|
|
build:windows --copt=-DNOGDI
|
|
|
|
build:windows --host_copt=-DNOGDI
|
|
|
|
|
|
|
|
# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/
|
|
|
|
# otherwise, there will be some compiling error due to preprocessing.
|
|
|
|
build:windows --copt=/Zc:preprocessor
|
|
|
|
|
2021-02-16 14:12:01 -05:00
|
|
|
build:posix --cxxopt=-std=c++14
|
|
|
|
build:posix --host_cxxopt=-std=c++14
|
2020-11-10 00:23:54 +08:00
|
|
|
|
|
|
|
build:windows --cxxopt=/std:c++14
|
|
|
|
build:windows --host_cxxopt=/std:c++14
|
|
|
|
|
2021-02-16 14:12:01 -05:00
|
|
|
build:linux --config=posix
|
|
|
|
build:macos --config=posix
|
|
|
|
|
2020-11-10 00:23:54 +08:00
|
|
|
# Generate PDB files, to generate useful PDBs, in opt compilation_mode
|
|
|
|
# --copt /Z7 is needed.
|
|
|
|
build:windows --linkopt=/DEBUG
|
|
|
|
build:windows --host_linkopt=/DEBUG
|
|
|
|
build:windows --linkopt=/OPT:REF
|
|
|
|
build:windows --host_linkopt=/OPT:REF
|
|
|
|
build:windows --linkopt=/OPT:ICF
|
|
|
|
build:windows --host_linkopt=/OPT:ICF
|
2020-11-24 23:44:26 +08:00
|
|
|
build:windows --experimental_strict_action_env=true
|
2020-02-22 09:45:24 -08:00
|
|
|
|
|
|
|
# Suppress all warning messages.
|
|
|
|
build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
2020-11-19 22:50:08 -05:00
|
|
|
|
|
|
|
# Workaround for gcc 10+ warnings related to upb.
|
|
|
|
# See https://github.com/tensorflow/tensorflow/issues/39467
|
|
|
|
build:linux --copt=-Wno-stringop-truncation
|
2018-11-26 12:37:24 -08:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
2019-04-25 16:19:53 -07:00
|
|
|
|
2020-11-10 00:23:54 +08:00
|
|
|
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None,
                  cuda_version=None, cudnn_version=None, rocm_toolkit_path=None,
                  cpu=None, **kwargs):
  """Generates ../.bazelrc, relative to the current working directory.

  The file begins with BAZELRC_TEMPLATE filled in from ``kwargs``. It then
  gets one ``--action_env`` line for each optional toolkit setting that is
  truthy, and finally the host/target CPU configuration.

  Args:
    cuda_toolkit_path: exported as CUDA_TOOLKIT_PATH and listed in
      TF_CUDA_PATHS when set.
    cudnn_install_path: exported as CUDNN_INSTALL_PATH and listed in
      TF_CUDA_PATHS when set.
    cuda_version: exported as TF_CUDA_VERSION when set.
    cudnn_version: exported as TF_CUDNN_VERSION when set.
    rocm_toolkit_path: exported as ROCM_PATH when set.
    cpu: if not None, emitted as ``build --cpu=...`` with distinct host and
      target configurations; otherwise the two configurations are shared.
    **kwargs: substitutions for the ``{placeholders}`` in BAZELRC_TEMPLATE.
  """
  with open("../.bazelrc", "w") as rc_file:
    rc_file.write(BAZELRC_TEMPLATE.format(**kwargs))

    def write_action_env(var_name, value):
      rc_file.write(f'build --action_env {var_name}="{value}"\n')

    cuda_search_paths = []
    for var_name, path in (("CUDA_TOOLKIT_PATH", cuda_toolkit_path),
                           ("CUDNN_INSTALL_PATH", cudnn_install_path)):
      if path:
        cuda_search_paths.append(path)
        write_action_env(var_name, path)
    if cuda_search_paths:
      write_action_env("TF_CUDA_PATHS", ",".join(cuda_search_paths))

    for var_name, value in (("TF_CUDA_VERSION", cuda_version),
                            ("TF_CUDNN_VERSION", cudnn_version),
                            ("ROCM_PATH", rocm_toolkit_path)):
      if value:
        write_action_env(var_name, value)

    if cpu is None:
      rc_file.write("build --distinct_host_configuration=false\n")
    else:
      rc_file.write("build --distinct_host_configuration=true\n")
      rc_file.write(f"build --cpu={cpu}\n")
|
|
|
|
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
BANNER = r"""
|
2018-12-05 08:22:27 -08:00
|
|
|
_ _ __ __
|
|
|
|
| | / \ \ \/ /
|
|
|
|
_ | |/ _ \ \ /
|
|
|
|
| |_| / ___ \/ \
|
|
|
|
\___/_/ \/_/\_\
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
EPILOG = """
|
|
|
|
|
2018-12-06 21:35:03 -05:00
|
|
|
From the 'build' directory in the JAX repository, run
|
|
|
|
python build.py
|
2018-11-26 12:37:24 -08:00
|
|
|
or
|
2018-12-06 21:35:03 -05:00
|
|
|
python3 build.py
|
|
|
|
to download and build JAX's XLA (jaxlib) dependency.
|
2018-11-26 12:37:24 -08:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_string_as_bool(s):
  """Converts "true" or "false" (in any letter case) to a bool.

  Raises:
    ValueError: for any other string. argparse reports this as an invalid
      value for the flag.
  """
  lowered = s.lower()
  known_values = {"true": True, "false": False}
  if lowered not in known_values:
    raise ValueError("Expected either 'true' or 'false'; got {}".format(s))
  return known_values[lowered]


def add_boolean_argument(parser, name, default=False, help_str=None):
  """Registers mutually exclusive ``--<name>`` and ``--no<name>`` flags.

  ``--<name>`` on its own sets the value to True. ``--<name> true|false`` (or
  ``--<name>=false``) sets it explicitly, and ``--no<name>`` sets it to
  False. With neither flag present, the value is ``default``.

  Args:
    parser: the argparse parser to extend.
    name: destination attribute name and flag stem.
    default: value used when neither flag is given.
    help_str: help text for the positive flag.
  """
  exclusive = parser.add_mutually_exclusive_group()
  exclusive.add_argument(
      "--" + name,
      type=_parse_string_as_bool,
      nargs="?",
      const=True,
      default=default,
      help=help_str)
  exclusive.add_argument("--no" + name, action="store_false", dest=name)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
  """Command-line entry point: builds a jaxlib wheel from source.

  Parses and validates the flags, locates (or downloads) Bazel, checks the
  target Python/NumPy/SciPy versions, prints the build configuration, writes
  ../.bazelrc, then runs the ``:build_wheel`` Bazel target and shuts the
  Bazel server down. Invalid flag combinations exit through
  ``parser.error``.
  """
  cwd = os.getcwd()
  parser = argparse.ArgumentParser(
      description="Builds jaxlib from source.", epilog=EPILOG)
  parser.add_argument(
      "--bazel_path",
      help="Path to the Bazel binary to use. The default is to find bazel via "
      "the PATH; if none is found, downloads a fresh copy of bazel from "
      "GitHub.")
  parser.add_argument(
      "--python_bin_path",
      help="Path to Python binary to use. The default is the Python "
      "interpreter used to run the build script.")
  parser.add_argument(
      "--target_cpu_features",
      choices=["release", "native", "default"],
      default="release",
      help="What CPU features should we target? 'release' enables CPU "
      "features that should be enabled for a release build, which on "
      "x86-64 architectures enables AVX. 'native' enables "
      "-march=native, which generates code targeted to use all "
      "features of the current machine. 'default' means don't opt-in "
      "to any architectural features and use whatever the C compiler "
      "generates by default.")
  add_boolean_argument(
      parser,
      "enable_mkl_dnn",
      default=True,
      help_str="Should we build with MKL-DNN enabled?")
  add_boolean_argument(
      parser,
      "enable_cuda",
      help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
  add_boolean_argument(
      parser,
      "enable_tpu",
      help_str="Should we build with Cloud TPU support enabled?")
  add_boolean_argument(
      parser,
      "enable_rocm",
      help_str="Should we build with ROCm enabled?")
  add_boolean_argument(
      parser,
      "enable_nccl",
      default=True,
      help_str="Should we build with NCCL enabled? Has no effect for non-CUDA "
      "builds.")
  parser.add_argument(
      "--cuda_path",
      default=None,
      help="Path to the CUDA toolkit.")
  parser.add_argument(
      "--cudnn_path",
      default=None,
      help="Path to CUDNN libraries.")
  parser.add_argument(
      "--cuda_version",
      default=None,
      help="CUDA toolkit version, e.g., 11.1")
  parser.add_argument(
      "--cudnn_version",
      default=None,
      help="CUDNN version, e.g., 8")
  parser.add_argument(
      "--cuda_compute_capabilities",
      default="3.5,5.2,6.0,6.1,7.0",
      help="A comma-separated list of CUDA compute capabilities to support.")
  parser.add_argument(
      "--rocm_path",
      default=None,
      help="Path to the ROCm toolkit.")
  parser.add_argument(
      "--rocm_amdgpu_targets",
      default="gfx803,gfx900,gfx906,gfx1010",
      help="A comma-separated list of ROCm amdgpu targets to support.")
  parser.add_argument(
      "--bazel_startup_options",
      action="append", default=[],
      help="Additional startup options to pass to bazel.")
  parser.add_argument(
      "--bazel_options",
      action="append", default=[],
      help="Additional options to pass to bazel.")
  parser.add_argument(
      "--output_path",
      default=os.path.join(cwd, "dist"),
      help="Directory to which the jaxlib wheel should be written")
  parser.add_argument(
      "--target_cpu",
      default=None,
      help="CPU platform to target. Default is the same as the host machine. "
      "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.")
  args = parser.parse_args()

  if is_windows() and args.enable_cuda:
    if args.cuda_version is None:
      parser.error("--cuda_version is needed for Windows CUDA build.")
    if args.cudnn_version is None:
      parser.error("--cudnn_version is needed for Windows CUDA build.")

  if args.enable_cuda and args.enable_rocm:
    parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.")

  # Maps --target_cpu values (also written to .bazelrc as bazel's --cpu) to
  # the CPU name handed to :build_wheel.
  wheel_cpus = {
      "darwin_arm64": "arm64",
      "darwin_x86_64": "x86_64",
      "ppc": "ppc64le",
  }
  # Reject unknown values up front instead of failing later with a KeyError.
  if args.target_cpu is not None and args.target_cpu not in wheel_cpus:
    parser.error("--target_cpu={} is not supported; expected one of: {}".format(
        args.target_cpu, ", ".join(sorted(wheel_cpus))))

  print(BANNER)

  # Resolve the output path against the caller's working directory before
  # switching into this script's directory, which the relative paths used
  # below ("../.bazelrc", ":build_wheel") are written against.
  output_path = os.path.abspath(args.output_path)
  os.chdir(os.path.dirname(__file__) or '.')

  host_cpu = platform.machine()
  # TODO(phawkins): support other bazel cpu overrides.
  wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None
               else host_cpu)

  # Find a working Bazel.
  bazel_path = get_bazel_path(args.bazel_path)
  print("Bazel binary path: {}".format(bazel_path))

  python_bin_path = get_python_bin_path(args.python_bin_path)
  print("Python binary path: {}".format(python_bin_path))
  python_version = get_python_version(python_bin_path)
  print("Python version: {}".format(".".join(map(str, python_version))))
  check_python_version(python_version)

  numpy_version = check_numpy_version(python_bin_path)
  print("NumPy version: {}".format(numpy_version))
  scipy_version = check_scipy_version(python_bin_path)
  print("SciPy version: {}".format(scipy_version))

  print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
  print("Target CPU: {}".format(wheel_cpu))
  print("Target CPU features: {}".format(args.target_cpu_features))

  cuda_toolkit_path = args.cuda_path
  cudnn_install_path = args.cudnn_path
  rocm_toolkit_path = args.rocm_path
  print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no"))
  if args.enable_cuda:
    if cuda_toolkit_path:
      print("CUDA toolkit path: {}".format(cuda_toolkit_path))
    if cudnn_install_path:
      print("CUDNN library path: {}".format(cudnn_install_path))
    print("CUDA compute capabilities: {}".format(args.cuda_compute_capabilities))
    if args.cuda_version:
      print("CUDA version: {}".format(args.cuda_version))
    if args.cudnn_version:
      print("CUDNN version: {}".format(args.cudnn_version))
    print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no"))

  print("TPU enabled: {}".format("yes" if args.enable_tpu else "no"))

  print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no"))
  if args.enable_rocm:
    if rocm_toolkit_path:
      print("ROCm toolkit path: {}".format(rocm_toolkit_path))
    print("ROCm amdgpu targets: {}".format(args.rocm_amdgpu_targets))

  write_bazelrc(
      python_bin_path=python_bin_path,
      tf_need_cuda=1 if args.enable_cuda else 0,
      tf_need_rocm=1 if args.enable_rocm else 0,
      cuda_toolkit_path=cuda_toolkit_path,
      cudnn_install_path=cudnn_install_path,
      cuda_compute_capabilities=args.cuda_compute_capabilities,
      cuda_version=args.cuda_version,
      cudnn_version=args.cudnn_version,
      rocm_toolkit_path=rocm_toolkit_path,
      rocm_amdgpu_targets=args.rocm_amdgpu_targets,
      cpu=args.target_cpu,
  )

  print("\nBuilding XLA and installing it in the jaxlib source tree...")
  # Copy so that appending below does not mutate args.bazel_options.
  config_args = list(args.bazel_options)
  config_args += ["--config=short_logs"]
  if args.target_cpu_features == "release":
    # platform.machine() spells 64-bit x86 "AMD64" on Windows; accept both so
    # the AVX release config also applies there.
    if wheel_cpu in ("x86_64", "AMD64"):
      config_args += ["--config=avx_windows" if is_windows()
                      else "--config=avx_posix"]
  elif args.target_cpu_features == "native":
    if is_windows():
      print("--target_cpu_features=native is not supported on Windows; ignoring.")
    else:
      config_args += ["--config=native_arch_posix"]

  if args.enable_mkl_dnn:
    config_args += ["--config=mkl_open_source_only"]
  if args.enable_cuda:
    config_args += ["--config=cuda"]
    config_args += ["--define=xla_python_enable_gpu=true"]
    if not args.enable_nccl:
      config_args += ["--config=nonccl"]
  if args.enable_tpu:
    config_args += ["--define=with_tpu_support=true"]
  if args.enable_rocm:
    config_args += ["--config=rocm"]
    config_args += ["--config=nonccl"]
    config_args += ["--define=xla_python_enable_gpu=true"]
  command = ([bazel_path] + args.bazel_startup_options +
             ["run", "--verbose_failures=true"] + config_args +
             [":build_wheel", "--",
              f"--output_path={output_path}",
              f"--cpu={wheel_cpu}"])
  print(" ".join(command))
  shell(command)
  shell([bazel_path, "shutdown"])
|
2018-11-26 12:37:24 -08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
main()
|