mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
255 lines
7.9 KiB
Python
255 lines
7.9 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
# Helper script for tools/utilities used by the JAX build CLI.
|
|
import collections
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import pathlib
|
|
import platform
|
|
import re
|
|
import shutil
|
|
import stat
|
|
import subprocess
|
|
import sys
|
|
import urllib.request
|
|
|
|
# Module-level logger for the build CLI helpers.
logger = logging.getLogger(__name__)

# Release directory that a package's file name is appended to whenever the
# package does not carry its own `base_uri`.
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/7.4.1/"

# One downloadable Bazel binary: an optional base URI override, the release
# file name, and the expected SHA256 hex digest of that file.
BazelPackage = collections.namedtuple(
    "BazelPackage", ["base_uri", "file", "sha256"]
)

# Known Bazel binaries, keyed by (system name, machine name). Every entry
# leaves `base_uri` unset.
bazel_packages = {
    (system, machine): BazelPackage(
        base_uri=None, file=file_name, sha256=digest
    )
    for system, machine, file_name, digest in (
        (
            "Linux",
            "x86_64",
            "bazel-7.4.1-linux-x86_64",
            "c97f02133adce63f0c28678ac1f21d65fa8255c80429b588aeeba8a1fac6202b",
        ),
        (
            "Linux",
            "aarch64",
            "bazel-7.4.1-linux-arm64",
            "d7aedc8565ed47b6231badb80b09f034e389c5f2b1c2ac2c55406f7c661d8b88",
        ),
        (
            "Darwin",
            "x86_64",
            "bazel-7.4.1-darwin-x86_64",
            "52dd34c17cc97b3aa5bdfe3d45c4e3938226f23dd0bfb47beedd625a953f1f05",
        ),
        (
            "Darwin",
            "arm64",
            "bazel-7.4.1-darwin-arm64",
            "02b117b97d0921ae4d4f4e11d27e2c0930381df416e373435d5d0419c6a26f24",
        ),
        (
            "Windows",
            "AMD64",
            "bazel-7.4.1-windows-x86_64.exe",
            "4a76eddf6c5115e1d93355fd11db5ac2fc20e58f197f5d65d3f21da92aa0925b",
        ),
    )
}
|
|
|
|
def download_and_verify_bazel():
  """Downloads a bazel binary from GitHub, verifying its SHA256 hash.

  The package is looked up in `bazel_packages` by
  `(platform.system(), platform.machine())`. If a file with the package's name
  is already present and executable in the current directory, nothing is
  downloaded and that existing file is not re-verified.

  Returns:
    The relative path `./<file>` of the bazel binary, or None if there is no
    package entry for the current platform.

  Exits the process with status -1 if the SHA256 digest of the downloaded file
  differs from the expected one.
  """
  package = bazel_packages.get((platform.system(), platform.machine()))
  if package is None:
    return None

  if not os.access(package.file, os.X_OK):
    uri = (package.base_uri or BAZEL_BASE_URI) + package.file
    sys.stdout.write(f"Downloading bazel from: {uri}\n")

    def progress(block_count, block_size, total_size):
      if total_size <= 0:
        # The server did not report a size. Fall back to a rough estimate that
        # is only used to draw the bar (the previous fallback, 170**6 bytes,
        # is ~2.4e13 and kept the bar pinned at 0%).
        total_size = 170 * 2**20
      # The last block is usually partial, so block_count * block_size can
      # overshoot total_size; clamp so the bar never exceeds 100%.
      fraction = min(1.0, (block_count * block_size) / total_size)
      num_chars = 40
      progress_chars = int(num_chars * fraction)
      sys.stdout.write(
          "{} [{}{}] {}%\r".format(
              package.file,
              "#" * progress_chars,
              "." * (num_chars - progress_chars),
              int(fraction * 100.0),
          )
      )

    try:
      tmp_path, _ = urllib.request.urlretrieve(
          uri, None, progress if sys.stdout.isatty() else None
      )
      sys.stdout.write("\n")

      # Verify that the downloaded Bazel binary has the expected SHA256.
      with open(tmp_path, "rb") as downloaded_file:
        contents = downloaded_file.read()
    finally:
      # Remove the temporary file urlretrieve() created. urlcleanup() only
      # touches files urlretrieve itself created, so a local source file (for
      # a file:// base_uri) is never deleted.
      urllib.request.urlcleanup()

    digest = hashlib.sha256(contents).hexdigest()
    if digest != package.sha256:
      print(
          "Checksum mismatch for downloaded bazel binary (expected {}; got {})."
          .format(package.sha256, digest)
      )
      sys.exit(-1)

    # Write the file as the bazel file name.
    with open(package.file, "wb") as out_file:
      out_file.write(contents)

    # Mark the file as executable.
    st = os.stat(package.file)
    os.chmod(
        package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
    )

  return os.path.join(".", package.file)
|
|
|
|
def get_bazel_paths(bazel_path_flag):
  """Lazily yields candidate Bazel binary paths, in order of preference.

  The candidates are, in order: the value of `bazel_path_flag`, the `bazel`
  executable found on the PATH, and the result of
  `download_and_verify_bazel()`. Any of them may be None.

  Each candidate is only computed when the caller asks for it, so the download
  (a side effect) happens only if the third element is consumed.
  """
  candidate_sources = (
      lambda: bazel_path_flag,
      lambda: shutil.which("bazel"),
      lambda: download_and_verify_bazel(),
  )
  yield from (source() for source in candidate_sources)
|
|
|
|
def get_bazel_path(bazel_path_flag):
  """Returns the path to a Bazel binary, downloading Bazel if not found.

  Candidates from `get_bazel_paths` are tried in order; the first one whose
  reported version is at least 7.4.1 is used.

  A manual version check is needed only for really old bazel versions.
  Newer bazel releases perform their own version check against .bazelversion
  (see for details
  https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).

  Returns:
    A `(path, version)` tuple, where `version` is the dotted version string
    reported by the selected binary.

  Exits the process with status -1 if no candidate is suitable.
  """
  # Keep the check and the error message in sync: the previous code compared
  # against (6, 5, 0) while telling the user that 7.4.1 was required.
  min_version = (7, 4, 1)
  for path in filter(None, get_bazel_paths(bazel_path_flag)):
    version = get_bazel_version(path)
    if version is not None and version >= min_version:
      return path, ".".join(map(str, version))

  print(
      "Cannot find or download a suitable version of bazel. "
      f"Please install bazel >= {'.'.join(map(str, min_version))}."
  )
  sys.exit(-1)
|
|
|
|
def get_bazel_version(bazel_path):
  """Returns the version reported by `<bazel_path> --version`.

  Args:
    bazel_path: path of the bazel executable to query.

  Returns:
    The version as a tuple of ints (e.g. `(7, 4, 1)`), or None if the binary
    cannot be run, exits with a non-zero status, or prints nothing of the form
    `bazel <number>[.<number>...]`.
  """
  try:
    version_output = subprocess.run(
        [bazel_path, "--version"],
        encoding="utf-8",
        capture_output=True,
        check=True,
    ).stdout.strip()
  except (subprocess.CalledProcessError, OSError):
    return None
  # Match only dot-separated runs of digits. The previous character class
  # `[0-9\\.]` also accepted backslashes and leading/trailing/doubled dots,
  # which made the int() conversion below raise ValueError.
  match = re.search(r"bazel *([0-9]+(?:\.[0-9]+)*)", version_output)
  if match is None:
    return None
  return tuple(int(x) for x in match.group(1).split("."))
|
|
|
|
def get_compiler_path_or_exit(compiler_path_flag, compiler_name):
  """Returns the fully resolved path of `compiler_name`, or exits.

  Args:
    compiler_path_flag: name of the build-script flag (without leading dashes)
      mentioned in the error message.
    compiler_name: executable to look up with `shutil.which`.

  Returns:
    The resolved absolute path of the compiler as a string.

  Exits the process with status -1 if the compiler cannot be found.
  """
  located = shutil.which(compiler_name)
  if not located:
    print(
        f"--{compiler_path_flag} is unset and {compiler_name} cannot be found"
        f" on the PATH. Please pass --{compiler_path_flag} to the build script."
    )
    sys.exit(-1)

  # The compiler was found on the path; hand back the fully resolved path to
  # ensure that system headers are found.
  return str(pathlib.Path(located).resolve())
|
|
|
|
def get_gcc_path_or_exit():
  """Returns the resolved path of `gcc`, exiting if it cannot be found."""
  return get_compiler_path_or_exit(
      compiler_path_flag="gcc_path", compiler_name="gcc"
  )
|
|
|
|
def get_clang_path_or_exit():
  """Returns the resolved path of `clang`, exiting if it cannot be found."""
  return get_compiler_path_or_exit(
      compiler_path_flag="clang_path", compiler_name="clang"
  )
|
|
|
|
def get_clang_major_version(clang_path):
  """Returns the major version of the clang binary at `clang_path` as an int.

  The text `__clang_major__` is fed to `clang -E -P -` on stdin and the
  captured stdout is parsed as an integer.

  Raises:
    subprocess.CalledProcessError: if the command exits with a non-zero status.
  """
  preprocess_cmd = [clang_path, "-E", "-P", "-"]
  completed = subprocess.run(
      preprocess_cmd,
      input="__clang_major__",
      text=True,
      capture_output=True,
      check=True,
  )
  return int(completed.stdout)
|
|
|
|
def get_gcc_major_version(gcc_path: str):
  """Returns the major version of the gcc binary at `gcc_path` as an int.

  Runs `gcc -dumpversion` and parses the part of its output that precedes the
  first dot.

  Raises:
    subprocess.CalledProcessError: if the command exits with a non-zero status.
  """
  completed = subprocess.run(
      [gcc_path, "-dumpversion"],
      text=True,
      capture_output=True,
      check=True,
  )
  major_text, _, _ = completed.stdout.partition(".")
  return int(major_text)
|
|
|
|
|
|
def get_jax_configure_bazel_options(bazel_command: list[str]):
  """Returns the bazel options to be written to .jax_configure.bazelrc.

  Args:
    bazel_command: the Bazel command line as a list of arguments. Every
      argument after the first "run" element is treated as a build option.

  Returns:
    A string with one `build <option>` line per option, or "" (after logging
    an error) if the command contains no "run" element.
  """
  # Build options come after "run", so locate it and keep everything after it.
  # The lookup has to sit inside the `try`: it is the only call here that can
  # raise ValueError.
  try:
    start = bazel_command.index("run")
  except ValueError:
    logging.error("Unable to find index for 'run' in the Bazel command")
    return ""

  bazel_flags = bazel_command[start + 1:]
  # On Windows, replace all backslashes with double backslashes to avoid
  # unintended escape sequences.
  if platform.system() == "Windows":
    bazel_flags = [flag.replace("\\", "\\\\") for flag in bazel_flags]
  return "".join(f"build {flag}\n" for flag in bazel_flags)
|
|
|
|
def get_githash():
  """Returns the output of `git rev-parse HEAD`, stripped of whitespace.

  Returns "" if git cannot be run or exits with a non-zero status.
  """
  rev_parse_cmd = ["git", "rev-parse", "HEAD"]
  try:
    completed = subprocess.run(
        rev_parse_cmd,
        check=True,
        capture_output=True,
        encoding="utf-8",
    )
  except (subprocess.CalledProcessError, OSError):
    return ""
  return completed.stdout.strip()
|
|
|
|
def _parse_string_as_bool(s):
|
|
"""Parses a string as a boolean value."""
|
|
lower = s.lower()
|
|
if lower == "true":
|
|
return True
|
|
elif lower == "false":
|
|
return False
|
|
else:
|
|
raise ValueError(f"Expected either 'true' or 'false'; got {s}")
|