Move utility functions in build.py to utils.py

This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.

PiperOrigin-RevId: 691458051
This commit is contained in:
Nitin Srinivasan 2024-10-30 09:59:56 -07:00 committed by jax authors
parent d2f5804449
commit da994d3552
2 changed files with 283 additions and 240 deletions

View File

@ -16,226 +16,17 @@
#
# Helper script for building JAX's libjax easily.
import argparse
import collections
import hashlib
import logging
import os
import pathlib
import platform
import re
import shutil
import stat
import subprocess
import sys
import textwrap
import urllib.request
from tools import utils
# Module-level logger named after this module; shell() logs each command it
# runs and any exception the subprocess raises through it.
logger = logging.getLogger(__name__)
def is_windows():
  """Returns True when `sys.platform` identifies a Windows ("win32") host."""
  platform_name = sys.platform
  return platform_name.startswith("win32")
def shell(cmd):
  """Runs `cmd` and returns its standard output as stripped text.

  Args:
    cmd: The command to execute, passed unchanged to
      `subprocess.check_output`.

  Returns:
    The command's stdout decoded as UTF-8 with surrounding whitespace removed.

  Raises:
    subprocess.CalledProcessError: If the command exits with a non-zero
      status. Any captured output is printed before the error is re-raised.
    Exception: Any other error raised while running the command is logged and
      re-raised unchanged.
  """
  try:
    logger.info("shell(): %s", cmd)
    output = subprocess.check_output(cmd)
  except subprocess.CalledProcessError as e:
    logger.info("subprocess raised: %s", e)
    failure_output = e.output
    if failure_output:
      # check_output captures bytes; decode them so the failure output is
      # printed as readable text rather than a b'...' literal with escaped
      # newlines. errors="replace" keeps this best-effort print from raising.
      if isinstance(failure_output, bytes):
        failure_output = failure_output.decode("UTF-8", errors="replace")
      print(failure_output)
    raise
  except Exception as e:
    logger.info("subprocess raised: %s", e)
    raise
  return output.decode("UTF-8").strip()
# Python
def get_python_bin_path(python_bin_path_flag):
  """Returns the Python interpreter path to use, with "/" as the separator.

  Args:
    python_bin_path_flag: An explicit interpreter path. When falsy, the
      interpreter running this script (`sys.executable`) is used instead.

  Returns:
    The chosen path with every `os.sep` replaced by "/".
  """
  interpreter = python_bin_path_flag if python_bin_path_flag else sys.executable
  return "/".join(interpreter.split(os.sep))
def get_python_version(python_bin_path):
  """Returns the `(major, minor)` version of the given Python interpreter.

  The interpreter is run through `shell` with a one-line script that prints
  its own "major.minor" version, and that output is parsed into two ints.

  Args:
    python_bin_path: Path to the Python interpreter to query.

  Returns:
    A `(major, minor)` tuple of ints.
  """
  probe_script = (
      'import sys; print("{}.{}".format(sys.version_info[0], '
      "sys.version_info[1]))"
  )
  reported = shell([python_bin_path, "-c", probe_script])
  major, minor = map(int, reported.split("."))
  return major, minor
def check_python_version(python_version):
  """Exits the process if `python_version` is older than Python 3.10.

  Args:
    python_version: A version tuple that is compared against `(3, 10)`.

  Returns:
    None when the version is acceptable. Otherwise an error is printed and
    `sys.exit(-1)` is called.
  """
  minimum_supported = (3, 10)
  if not python_version < minimum_supported:
    return
  print("ERROR: JAX requires Python 3.10 or newer, found ", python_version)
  sys.exit(-1)
def get_githash():
  """Returns the output of `git rev-parse HEAD`, stripped of whitespace.

  Returns:
    The stripped stdout of the git command, or "" when running git raises
    `OSError`. The command's exit status is not checked.
  """
  rev_parse_cmd = ["git", "rev-parse", "HEAD"]
  try:
    completed = subprocess.run(
        rev_parse_cmd, capture_output=True, encoding="utf-8"
    )
  except OSError:
    return ""
  return completed.stdout.strip()
# Bazel
# Default base URL for Bazel binaries; download_and_verify_bazel() appends the
# package file name to it when a package has no base_uri of its own.
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"
# One downloadable Bazel binary: an optional base URI (None means use
# BAZEL_BASE_URI), the file name, and the expected SHA256 hex digest.
BazelPackage = collections.namedtuple("BazelPackage",
["base_uri", "file", "sha256"])
# Bazel binaries keyed by (platform.system(), platform.machine()).
bazel_packages = {
("Linux", "x86_64"):
BazelPackage(
base_uri=None,
file="bazel-6.5.0-linux-x86_64",
sha256=
"a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"),
("Linux", "aarch64"):
BazelPackage(
base_uri=None,
file="bazel-6.5.0-linux-arm64",
sha256=
"5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"),
("Darwin", "x86_64"):
BazelPackage(
base_uri=None,
file="bazel-6.5.0-darwin-x86_64",
sha256=
"bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"),
("Darwin", "arm64"):
BazelPackage(
base_uri=None,
file="bazel-6.5.0-darwin-arm64",
sha256=
"c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"),
("Windows", "AMD64"):
BazelPackage(
base_uri=None,
file="bazel-6.5.0-windows-x86_64.exe",
sha256=
"6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"),
}
def download_and_verify_bazel():
  """Downloads a bazel binary from GitHub, verifying its SHA256 hash.

  The package is looked up in `bazel_packages` by
  `(platform.system(), platform.machine())`. The binary is written to the
  current working directory under the package's file name and marked
  executable. The download is skipped when an executable file with that name
  already exists.

  Returns:
    The path to the bazel binary relative to the current directory, or None
    when no package is registered for the current platform.

  Exits:
    Calls `sys.exit(-1)` when the downloaded file's SHA256 digest does not
    match the expected digest.
  """
  package = bazel_packages.get((platform.system(), platform.machine()))
  if package is None:
    return None

  if not os.access(package.file, os.X_OK):
    uri = (package.base_uri or BAZEL_BASE_URI) + package.file
    sys.stdout.write(f"Downloading bazel from: {uri}\n")

    def progress(block_count, block_size, total_size):
      if total_size <= 0:
        # NOTE(review): fallback used when the server reports no size; the
        # intent behind this particular value is unclear -- confirm.
        total_size = 170**6
      # The last block is usually only partly filled, so the raw ratio can
      # exceed 1.0; clamp it so the bar never reports more than 100%.
      fraction = min((block_count * block_size) / total_size, 1.0)
      num_chars = 40
      progress_chars = int(num_chars * fraction)
      sys.stdout.write("{} [{}{}] {}%\r".format(
          package.file, "#" * progress_chars,
          "." * (num_chars - progress_chars), int(fraction * 100.0)))

    tmp_path, _ = urllib.request.urlretrieve(
        uri, None, progress if sys.stdout.isatty() else None
    )
    sys.stdout.write("\n")

    # Verify that the downloaded Bazel binary has the expected SHA256.
    try:
      with open(tmp_path, "rb") as downloaded_file:
        contents = downloaded_file.read()
    finally:
      # urlretrieve() stored the download in a temporary file. Its contents
      # are now in memory (or unreadable), so remove it on a best-effort
      # basis rather than leaving a large binary in the temp directory.
      try:
        os.remove(tmp_path)
      except OSError:
        pass
    digest = hashlib.sha256(contents).hexdigest()
    if digest != package.sha256:
      print(
          "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
          .format(package.sha256, digest))
      sys.exit(-1)

    # Write the file as the bazel file name.
    with open(package.file, "wb") as out_file:
      out_file.write(contents)

    # Mark the file as executable.
    st = os.stat(package.file)
    os.chmod(package.file,
             st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)

  return os.path.join(".", package.file)
def get_bazel_paths(bazel_path_flag):
  """Yields candidate Bazel paths in order of preference.

  The candidates are the value of `bazel_path_flag`, then the result of
  looking up `bazel` on the PATH, then the result of
  `download_and_verify_bazel()`. Some elements may be None. The iterator is
  lazy: each lookup, including the download and its side effects, only runs
  when its element is requested.
  """
  lookups = (
      lambda: bazel_path_flag,
      lambda: shutil.which("bazel"),
      lambda: download_and_verify_bazel(),
  )
  for lookup in lookups:
    yield lookup()
def get_bazel_path(bazel_path_flag):
  """Returns the path to a Bazel binary, downloading Bazel if not found.

  Candidates from `get_bazel_paths` are tried in order, skipping falsy ones,
  and the first one whose reported version is at least 6.5.0 is used.

  A manual version check is needed only for really old bazel versions.
  Newer bazel releases perform their own version check against .bazelversion
  (see for details
  https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).

  Args:
    bazel_path_flag: An explicit Bazel path to try first; may be falsy.

  Returns:
    A `(path, version)` tuple where `version` is the dotted version string.
    If no candidate is suitable, an error is printed and `sys.exit(-1)` is
    called.
  """
  for path in filter(None, get_bazel_paths(bazel_path_flag)):
    version = get_bazel_version(path)
    if version is not None and version >= (6, 5, 0):
      return path, ".".join(map(str, version))

  # The two literals are concatenated, so the first must end with a space to
  # keep the sentences apart in the printed message.
  print("Cannot find or download a suitable version of bazel. "
        "Please install bazel >= 6.5.0.")
  sys.exit(-1)
def get_bazel_version(bazel_path):
  """Returns the version reported by `bazel_path --version` as an int tuple.

  Args:
    bazel_path: Path to the Bazel executable to query.

  Returns:
    A tuple of ints such as `(6, 5, 0)`, or None when the command cannot be
    run, exits with an error, or prints no recognizable version.
  """
  try:
    version_output = shell([bazel_path, "--version"])
  except (subprocess.CalledProcessError, OSError):
    return None
  # Capture dot-separated digit groups only, so every component is a valid
  # int() argument (no stray backslashes, no empty component from a trailing
  # or doubled dot).
  match = re.search(r"bazel *([0-9]+(?:\.[0-9]+)*)", version_output)
  if match is None:
    return None
  return tuple(int(x) for x in match.group(1).split("."))
def get_clang_path_or_exit():
  """Returns the fully resolved path of the `clang` found on the PATH.

  Returns:
    The resolved clang path as a string. If no clang is found on the PATH, a
    message is printed and `sys.exit(-1)` is called.
  """
  clang_on_path = shutil.which("clang")
  if not clang_on_path:
    print(
        "--clang_path is unset and clang cannot be found"
        " on the PATH. Please pass --clang_path directly."
    )
    sys.exit(-1)
  # If we've found a clang on the path, need to get the fully resolved path
  # to ensure that system headers are found.
  return str(pathlib.Path(clang_on_path).resolve())
def get_clang_major_version(clang_path):
  """Returns the major version of the clang binary at `clang_path`.

  Runs clang with `-E -P -` on the text `__clang_major__` supplied via stdin
  and parses the resulting stdout as an integer.

  Args:
    clang_path: Path to the clang executable.

  Returns:
    The major version as an int.

  Raises:
    subprocess.CalledProcessError: If clang exits with a non-zero status.
  """
  preprocess_cmd = [clang_path, "-E", "-P", "-"]
  completed = subprocess.run(
      preprocess_cmd,
      input="__clang_major__",
      text=True,
      capture_output=True,
      check=True,
  )
  return int(completed.stdout)
def write_bazelrc(*, remote_build,
cuda_version, cudnn_version, rocm_toolkit_path,
cpu, cuda_compute_capabilities,
@ -272,10 +63,10 @@ def write_bazelrc(*, remote_build,
if target_cpu_features == "release":
if wheel_cpu == "x86_64":
f.write("build --config=avx_windows\n" if is_windows()
f.write("build --config=avx_windows\n" if utils.is_windows()
else "build --config=avx_posix\n")
elif target_cpu_features == "native":
if is_windows():
if utils.is_windows():
print("--target_cpu_features=native is not supported on Windows; ignoring.")
else:
f.write("build --config=native_arch_posix\n")
@ -575,18 +366,18 @@ def main():
else host_cpu)
# Find a working Bazel.
bazel_path, bazel_version = get_bazel_path(args.bazel_path)
bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path)
print(f"Bazel binary path: {bazel_path}")
print(f"Bazel version: {bazel_version}")
if args.python_version:
python_version = args.python_version
else:
python_bin_path = get_python_bin_path(args.python_bin_path)
python_bin_path = utils.get_python_bin_path(args.python_bin_path)
print(f"Python binary path: {python_bin_path}")
python_version = get_python_version(python_bin_path)
python_version = utils.get_python_version(python_bin_path)
print("Python version: {}".format(".".join(map(str, python_version))))
check_python_version(python_version)
utils.check_python_version(python_version)
python_version = ".".join(map(str, python_version))
print("Use clang: {}".format("yes" if args.use_clang else "no"))
@ -594,9 +385,9 @@ def main():
clang_major_version = None
if args.use_clang:
if not clang_path:
clang_path = get_clang_path_or_exit()
clang_path = utils.get_clang_path_or_exit()
print(f"clang path: {clang_path}")
clang_major_version = get_clang_major_version(clang_path)
clang_major_version = utils.get_clang_major_version(clang_path)
print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
print(f"Target CPU: {wheel_cpu}")
@ -648,7 +439,7 @@ def main():
update_command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true", task, *args.bazel_options])
print(" ".join(update_command))
shell(update_command)
utils.shell(update_command)
return
if args.configure_only:
@ -675,27 +466,29 @@ def main():
if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin:
build_cpu_wheel_command = [
*command_base,
"//jaxlib/tools:build_wheel", "--",
f"--output_path={output_path_jaxlib}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}"
*command_base,
"//jaxlib/tools:build_wheel",
"--",
f"--output_path={output_path_jaxlib}",
f"--jaxlib_git_hash={utils.get_githash()}",
f"--cpu={wheel_cpu}",
]
if args.build_gpu_plugin:
build_cpu_wheel_command.append("--skip_gpu_kernels")
if args.editable:
build_cpu_wheel_command.append("--editable")
print(" ".join(build_cpu_wheel_command))
shell(build_cpu_wheel_command)
utils.shell(build_cpu_wheel_command)
if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \
(args.build_gpu_kernel_plugin == "rocm"):
build_gpu_kernels_command = [
*command_base,
"//jaxlib/tools:build_gpu_kernels_wheel", "--",
f"--output_path={output_path_jax_kernel}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}",
*command_base,
"//jaxlib/tools:build_gpu_kernels_wheel",
"--",
f"--output_path={output_path_jax_kernel}",
f"--jaxlib_git_hash={utils.get_githash()}",
f"--cpu={wheel_cpu}",
]
if args.enable_cuda:
build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}")
@ -708,15 +501,16 @@ def main():
if args.editable:
build_gpu_kernels_command.append("--editable")
print(" ".join(build_gpu_kernels_command))
shell(build_gpu_kernels_command)
utils.shell(build_gpu_kernels_command)
if args.build_gpu_plugin or args.build_gpu_pjrt_plugin:
build_pjrt_plugin_command = [
*command_base,
"//jaxlib/tools:build_gpu_plugin_wheel", "--",
f"--output_path={output_path_jax_pjrt}",
f"--jaxlib_git_hash={get_githash()}",
f"--cpu={wheel_cpu}",
*command_base,
"//jaxlib/tools:build_gpu_plugin_wheel",
"--",
f"--output_path={output_path_jax_pjrt}",
f"--jaxlib_git_hash={utils.get_githash()}",
f"--cpu={wheel_cpu}",
]
if args.enable_cuda:
build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}")
@ -729,9 +523,9 @@ def main():
if args.editable:
build_pjrt_plugin_command.append("--editable")
print(" ".join(build_pjrt_plugin_command))
shell(build_pjrt_plugin_command)
utils.shell(build_pjrt_plugin_command)
shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
utils.shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
if __name__ == "__main__":

249
build/tools/utils.py Normal file
View File

@ -0,0 +1,249 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Helper script for tools/utilities used by the JAX build CLI.
import collections
import hashlib
import logging
import os
import pathlib
import platform
import re
import shutil
import stat
import subprocess
import sys
import urllib.request
# Module-level logger named after this module; shell() logs each command it
# runs and any exception the subprocess raises through it.
logger = logging.getLogger(__name__)
def is_windows():
  """Returns True when `sys.platform` identifies a Windows ("win32") host."""
  platform_name = sys.platform
  return platform_name.startswith("win32")
def shell(cmd):
  """Runs `cmd` and returns its standard output as stripped text.

  Args:
    cmd: The command to execute, passed unchanged to
      `subprocess.check_output`.

  Returns:
    The command's stdout decoded as UTF-8 with surrounding whitespace removed.

  Raises:
    subprocess.CalledProcessError: If the command exits with a non-zero
      status. Any captured output is printed before the error is re-raised.
    Exception: Any other error raised while running the command is logged and
      re-raised unchanged.
  """
  try:
    logger.info("shell(): %s", cmd)
    output = subprocess.check_output(cmd)
  except subprocess.CalledProcessError as e:
    logger.info("subprocess raised: %s", e)
    failure_output = e.output
    if failure_output:
      # check_output captures bytes; decode them so the failure output is
      # printed as readable text rather than a b'...' literal with escaped
      # newlines. errors="replace" keeps this best-effort print from raising.
      if isinstance(failure_output, bytes):
        failure_output = failure_output.decode("UTF-8", errors="replace")
      print(failure_output)
    raise
  except Exception as e:
    logger.info("subprocess raised: %s", e)
    raise
  return output.decode("UTF-8").strip()
# Bazel
# Default base URL for Bazel binaries; download_and_verify_bazel() appends the
# package file name to it when a package has no base_uri of its own.
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"
# One downloadable Bazel binary: an optional base URI (None means use
# BAZEL_BASE_URI), the file name, and the expected SHA256 hex digest.
BazelPackage = collections.namedtuple(
"BazelPackage", ["base_uri", "file", "sha256"]
)
# Bazel binaries keyed by (platform.system(), platform.machine()).
bazel_packages = {
("Linux", "x86_64"): BazelPackage(
base_uri=None,
file="bazel-6.5.0-linux-x86_64",
sha256=(
"a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"
),
),
("Linux", "aarch64"): BazelPackage(
base_uri=None,
file="bazel-6.5.0-linux-arm64",
sha256=(
"5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"
),
),
("Darwin", "x86_64"): BazelPackage(
base_uri=None,
file="bazel-6.5.0-darwin-x86_64",
sha256=(
"bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"
),
),
("Darwin", "arm64"): BazelPackage(
base_uri=None,
file="bazel-6.5.0-darwin-arm64",
sha256=(
"c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"
),
),
("Windows", "AMD64"): BazelPackage(
base_uri=None,
file="bazel-6.5.0-windows-x86_64.exe",
sha256=(
"6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"
),
),
}
def download_and_verify_bazel():
  """Downloads a bazel binary from GitHub, verifying its SHA256 hash.

  The package is looked up in `bazel_packages` by
  `(platform.system(), platform.machine())`. The binary is written to the
  current working directory under the package's file name and marked
  executable. The download is skipped when an executable file with that name
  already exists.

  Returns:
    The path to the bazel binary relative to the current directory, or None
    when no package is registered for the current platform.

  Exits:
    Calls `sys.exit(-1)` when the downloaded file's SHA256 digest does not
    match the expected digest.
  """
  package = bazel_packages.get((platform.system(), platform.machine()))
  if package is None:
    return None

  if not os.access(package.file, os.X_OK):
    uri = (package.base_uri or BAZEL_BASE_URI) + package.file
    sys.stdout.write(f"Downloading bazel from: {uri}\n")

    def progress(block_count, block_size, total_size):
      if total_size <= 0:
        # NOTE(review): fallback used when the server reports no size; the
        # intent behind this particular value is unclear -- confirm.
        total_size = 170**6
      # The last block is usually only partly filled, so the raw ratio can
      # exceed 1.0; clamp it so the bar never reports more than 100%.
      fraction = min((block_count * block_size) / total_size, 1.0)
      num_chars = 40
      progress_chars = int(num_chars * fraction)
      sys.stdout.write(
          "{} [{}{}] {}%\r".format(
              package.file,
              "#" * progress_chars,
              "." * (num_chars - progress_chars),
              int(fraction * 100.0),
          )
      )

    tmp_path, _ = urllib.request.urlretrieve(
        uri, None, progress if sys.stdout.isatty() else None
    )
    sys.stdout.write("\n")

    # Verify that the downloaded Bazel binary has the expected SHA256.
    try:
      with open(tmp_path, "rb") as downloaded_file:
        contents = downloaded_file.read()
    finally:
      # urlretrieve() stored the download in a temporary file. Its contents
      # are now in memory (or unreadable), so remove it on a best-effort
      # basis rather than leaving a large binary in the temp directory.
      try:
        os.remove(tmp_path)
      except OSError:
        pass
    digest = hashlib.sha256(contents).hexdigest()
    if digest != package.sha256:
      print(
          "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
          .format(package.sha256, digest)
      )
      sys.exit(-1)

    # Write the file as the bazel file name.
    with open(package.file, "wb") as out_file:
      out_file.write(contents)

    # Mark the file as executable.
    st = os.stat(package.file)
    os.chmod(
        package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
    )

  return os.path.join(".", package.file)
def get_bazel_paths(bazel_path_flag):
  """Yields candidate Bazel paths in order of preference.

  The candidates are the value of `bazel_path_flag`, then the result of
  looking up `bazel` on the PATH, then the result of
  `download_and_verify_bazel()`. Some elements may be None. The iterator is
  lazy: each lookup, including the download and its side effects, only runs
  when its element is requested.
  """
  lookups = (
      lambda: bazel_path_flag,
      lambda: shutil.which("bazel"),
      lambda: download_and_verify_bazel(),
  )
  for lookup in lookups:
    yield lookup()
def get_bazel_path(bazel_path_flag):
  """Returns the path to a Bazel binary, downloading Bazel if not found.

  Candidates from `get_bazel_paths` are tried in order, skipping falsy ones,
  and the first one whose reported version is at least 6.5.0 is used.

  A manual version check is needed only for really old bazel versions.
  Newer bazel releases perform their own version check against .bazelversion
  (see for details
  https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).

  Args:
    bazel_path_flag: An explicit Bazel path to try first; may be falsy.

  Returns:
    A `(path, version)` tuple where `version` is the dotted version string.
    If no candidate is suitable, an error is printed and `sys.exit(-1)` is
    called.
  """
  for path in filter(None, get_bazel_paths(bazel_path_flag)):
    version = get_bazel_version(path)
    if version is not None and version >= (6, 5, 0):
      return path, ".".join(map(str, version))

  # The two literals are concatenated, so the first must end with a space to
  # keep the sentences apart in the printed message.
  print(
      "Cannot find or download a suitable version of bazel. "
      "Please install bazel >= 6.5.0."
  )
  sys.exit(-1)
def get_bazel_version(bazel_path):
  """Returns the version reported by `bazel_path --version` as an int tuple.

  Args:
    bazel_path: Path to the Bazel executable to query.

  Returns:
    A tuple of ints such as `(6, 5, 0)`, or None when the command cannot be
    run, exits with an error, or prints no recognizable version.
  """
  try:
    version_output = shell([bazel_path, "--version"])
  except (subprocess.CalledProcessError, OSError):
    return None
  # Capture dot-separated digit groups only, so every component is a valid
  # int() argument (no stray backslashes, no empty component from a trailing
  # or doubled dot).
  match = re.search(r"bazel *([0-9]+(?:\.[0-9]+)*)", version_output)
  if match is None:
    return None
  return tuple(int(x) for x in match.group(1).split("."))
def get_clang_path_or_exit():
  """Returns the fully resolved path of the `clang` found on the PATH.

  Returns:
    The resolved clang path as a string. If no clang is found on the PATH, a
    message is printed and `sys.exit(-1)` is called.
  """
  clang_on_path = shutil.which("clang")
  if not clang_on_path:
    print(
        "--clang_path is unset and clang cannot be found"
        " on the PATH. Please pass --clang_path directly."
    )
    sys.exit(-1)
  # If we've found a clang on the path, need to get the fully resolved path
  # to ensure that system headers are found.
  return str(pathlib.Path(clang_on_path).resolve())
def get_clang_major_version(clang_path):
  """Returns the major version of the clang binary at `clang_path`.

  Runs clang with `-E -P -` on the text `__clang_major__` supplied via stdin
  and parses the resulting stdout as an integer.

  Args:
    clang_path: Path to the clang executable.

  Returns:
    The major version as an int.

  Raises:
    subprocess.CalledProcessError: If clang exits with a non-zero status.
  """
  preprocess_cmd = [clang_path, "-E", "-P", "-"]
  completed = subprocess.run(
      preprocess_cmd,
      input="__clang_major__",
      text=True,
      capture_output=True,
      check=True,
  )
  return int(completed.stdout)
# Python
def get_python_bin_path(python_bin_path_flag):
  """Returns the Python interpreter path to use, with "/" as the separator.

  Args:
    python_bin_path_flag: An explicit interpreter path. When falsy, the
      interpreter running this script (`sys.executable`) is used instead.

  Returns:
    The chosen path with every `os.sep` replaced by "/".
  """
  interpreter = python_bin_path_flag if python_bin_path_flag else sys.executable
  return "/".join(interpreter.split(os.sep))
def get_python_version(python_bin_path):
  """Returns the `(major, minor)` version of the given Python interpreter.

  The interpreter is run through `shell` with a one-line script that prints
  its own "major.minor" version, and that output is parsed into two ints.

  Args:
    python_bin_path: Path to the Python interpreter to query.

  Returns:
    A `(major, minor)` tuple of ints.
  """
  probe_script = (
      'import sys; print("{}.{}".format(sys.version_info[0], '
      "sys.version_info[1]))"
  )
  reported = shell([python_bin_path, "-c", probe_script])
  major, minor = map(int, reported.split("."))
  return major, minor
def check_python_version(python_version):
  """Exits the process if `python_version` is older than Python 3.10.

  Args:
    python_version: A version tuple that is compared against `(3, 10)`.

  Returns:
    None when the version is acceptable. Otherwise an error is printed and
    `sys.exit(-1)` is called.
  """
  minimum_supported = (3, 10)
  if not python_version < minimum_supported:
    return
  print("ERROR: JAX requires Python 3.10 or newer, found ", python_version)
  sys.exit(-1)
def get_githash():
  """Returns the output of `git rev-parse HEAD`, stripped of whitespace.

  Returns:
    The stripped stdout of the git command, or "" when running git raises
    `OSError`. The command's exit status is not checked.
  """
  rev_parse_cmd = ["git", "rev-parse", "HEAD"]
  try:
    completed = subprocess.run(
        rev_parse_cmd, capture_output=True, encoding="utf-8"
    )
  except OSError:
    return ""
  return completed.stdout.strip()