mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Move utility functions in build.py to utils.py
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability. PiperOrigin-RevId: 691458051
This commit is contained in:
parent
d2f5804449
commit
da994d3552
274
build/build.py
274
build/build.py
@ -16,226 +16,17 @@
|
||||
#
|
||||
# Helper script for building JAX's libjax easily.
|
||||
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
import urllib.request
|
||||
|
||||
from tools import utils
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_windows():
|
||||
return sys.platform.startswith("win32")
|
||||
|
||||
|
||||
def shell(cmd):
|
||||
try:
|
||||
logger.info("shell(): %s", cmd)
|
||||
output = subprocess.check_output(cmd)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.info("subprocess raised: %s", e)
|
||||
if e.output: print(e.output)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.info("subprocess raised: %s", e)
|
||||
raise
|
||||
return output.decode("UTF-8").strip()
|
||||
|
||||
|
||||
# Python
|
||||
|
||||
def get_python_bin_path(python_bin_path_flag):
|
||||
"""Returns the path to the Python interpreter to use."""
|
||||
path = python_bin_path_flag or sys.executable
|
||||
return path.replace(os.sep, "/")
|
||||
|
||||
|
||||
def get_python_version(python_bin_path):
|
||||
version_output = shell(
|
||||
[python_bin_path, "-c",
|
||||
("import sys; print(\"{}.{}\".format(sys.version_info[0], "
|
||||
"sys.version_info[1]))")])
|
||||
major, minor = map(int, version_output.split("."))
|
||||
return major, minor
|
||||
|
||||
def check_python_version(python_version):
|
||||
if python_version < (3, 10):
|
||||
print("ERROR: JAX requires Python 3.10 or newer, found ", python_version)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def get_githash():
|
||||
try:
|
||||
return subprocess.run(
|
||||
["git", "rev-parse", "HEAD"],
|
||||
encoding='utf-8',
|
||||
capture_output=True).stdout.strip()
|
||||
except OSError:
|
||||
return ""
|
||||
|
||||
# Bazel
|
||||
|
||||
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"
|
||||
BazelPackage = collections.namedtuple("BazelPackage",
|
||||
["base_uri", "file", "sha256"])
|
||||
bazel_packages = {
|
||||
("Linux", "x86_64"):
|
||||
BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-linux-x86_64",
|
||||
sha256=
|
||||
"a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"),
|
||||
("Linux", "aarch64"):
|
||||
BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-linux-arm64",
|
||||
sha256=
|
||||
"5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"),
|
||||
("Darwin", "x86_64"):
|
||||
BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-darwin-x86_64",
|
||||
sha256=
|
||||
"bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"),
|
||||
("Darwin", "arm64"):
|
||||
BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-darwin-arm64",
|
||||
sha256=
|
||||
"c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"),
|
||||
("Windows", "AMD64"):
|
||||
BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-windows-x86_64.exe",
|
||||
sha256=
|
||||
"6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"),
|
||||
}
|
||||
|
||||
|
||||
def download_and_verify_bazel():
|
||||
"""Downloads a bazel binary from GitHub, verifying its SHA256 hash."""
|
||||
package = bazel_packages.get((platform.system(), platform.machine()))
|
||||
if package is None:
|
||||
return None
|
||||
|
||||
if not os.access(package.file, os.X_OK):
|
||||
uri = (package.base_uri or BAZEL_BASE_URI) + package.file
|
||||
sys.stdout.write(f"Downloading bazel from: {uri}\n")
|
||||
|
||||
def progress(block_count, block_size, total_size):
|
||||
if total_size <= 0:
|
||||
total_size = 170**6
|
||||
progress = (block_count * block_size) / total_size
|
||||
num_chars = 40
|
||||
progress_chars = int(num_chars * progress)
|
||||
sys.stdout.write("{} [{}{}] {}%\r".format(
|
||||
package.file, "#" * progress_chars,
|
||||
"." * (num_chars - progress_chars), int(progress * 100.0)))
|
||||
|
||||
tmp_path, _ = urllib.request.urlretrieve(
|
||||
uri, None, progress if sys.stdout.isatty() else None
|
||||
)
|
||||
sys.stdout.write("\n")
|
||||
|
||||
# Verify that the downloaded Bazel binary has the expected SHA256.
|
||||
with open(tmp_path, "rb") as downloaded_file:
|
||||
contents = downloaded_file.read()
|
||||
|
||||
digest = hashlib.sha256(contents).hexdigest()
|
||||
if digest != package.sha256:
|
||||
print(
|
||||
"Checksum mismatch for downloaded bazel binary (expected {}; got {})."
|
||||
.format(package.sha256, digest))
|
||||
sys.exit(-1)
|
||||
|
||||
# Write the file as the bazel file name.
|
||||
with open(package.file, "wb") as out_file:
|
||||
out_file.write(contents)
|
||||
|
||||
# Mark the file as executable.
|
||||
st = os.stat(package.file)
|
||||
os.chmod(package.file,
|
||||
st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
|
||||
|
||||
return os.path.join(".", package.file)
|
||||
|
||||
|
||||
def get_bazel_paths(bazel_path_flag):
|
||||
"""Yields a sequence of guesses about bazel path. Some of sequence elements
|
||||
can be None. The resulting iterator is lazy and potentially has a side
|
||||
effects."""
|
||||
yield bazel_path_flag
|
||||
yield shutil.which("bazel")
|
||||
yield download_and_verify_bazel()
|
||||
|
||||
|
||||
def get_bazel_path(bazel_path_flag):
|
||||
"""Returns the path to a Bazel binary, downloading Bazel if not found. Also,
|
||||
checks Bazel's version is at least newer than 6.5.0
|
||||
|
||||
A manual version check is needed only for really old bazel versions.
|
||||
Newer bazel releases perform their own version check against .bazelversion
|
||||
(see for details
|
||||
https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).
|
||||
"""
|
||||
for path in filter(None, get_bazel_paths(bazel_path_flag)):
|
||||
version = get_bazel_version(path)
|
||||
if version is not None and version >= (6, 5, 0):
|
||||
return path, ".".join(map(str, version))
|
||||
|
||||
print("Cannot find or download a suitable version of bazel."
|
||||
"Please install bazel >= 6.5.0.")
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def get_bazel_version(bazel_path):
|
||||
try:
|
||||
version_output = shell([bazel_path, "--version"])
|
||||
except (subprocess.CalledProcessError, OSError):
|
||||
return None
|
||||
match = re.search(r"bazel *([0-9\\.]+)", version_output)
|
||||
if match is None:
|
||||
return None
|
||||
return tuple(int(x) for x in match.group(1).split("."))
|
||||
|
||||
|
||||
def get_clang_path_or_exit():
|
||||
which_clang_output = shutil.which("clang")
|
||||
if which_clang_output:
|
||||
# If we've found a clang on the path, need to get the fully resolved path
|
||||
# to ensure that system headers are found.
|
||||
return str(pathlib.Path(which_clang_output).resolve())
|
||||
else:
|
||||
print(
|
||||
"--clang_path is unset and clang cannot be found"
|
||||
" on the PATH. Please pass --clang_path directly."
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
def get_clang_major_version(clang_path):
|
||||
clang_version_proc = subprocess.run(
|
||||
[clang_path, "-E", "-P", "-"],
|
||||
input="__clang_major__",
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
major_version = int(clang_version_proc.stdout)
|
||||
|
||||
return major_version
|
||||
|
||||
|
||||
def write_bazelrc(*, remote_build,
|
||||
cuda_version, cudnn_version, rocm_toolkit_path,
|
||||
cpu, cuda_compute_capabilities,
|
||||
@ -272,10 +63,10 @@ def write_bazelrc(*, remote_build,
|
||||
|
||||
if target_cpu_features == "release":
|
||||
if wheel_cpu == "x86_64":
|
||||
f.write("build --config=avx_windows\n" if is_windows()
|
||||
f.write("build --config=avx_windows\n" if utils.is_windows()
|
||||
else "build --config=avx_posix\n")
|
||||
elif target_cpu_features == "native":
|
||||
if is_windows():
|
||||
if utils.is_windows():
|
||||
print("--target_cpu_features=native is not supported on Windows; ignoring.")
|
||||
else:
|
||||
f.write("build --config=native_arch_posix\n")
|
||||
@ -575,18 +366,18 @@ def main():
|
||||
else host_cpu)
|
||||
|
||||
# Find a working Bazel.
|
||||
bazel_path, bazel_version = get_bazel_path(args.bazel_path)
|
||||
bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path)
|
||||
print(f"Bazel binary path: {bazel_path}")
|
||||
print(f"Bazel version: {bazel_version}")
|
||||
|
||||
if args.python_version:
|
||||
python_version = args.python_version
|
||||
else:
|
||||
python_bin_path = get_python_bin_path(args.python_bin_path)
|
||||
python_bin_path = utils.get_python_bin_path(args.python_bin_path)
|
||||
print(f"Python binary path: {python_bin_path}")
|
||||
python_version = get_python_version(python_bin_path)
|
||||
python_version = utils.get_python_version(python_bin_path)
|
||||
print("Python version: {}".format(".".join(map(str, python_version))))
|
||||
check_python_version(python_version)
|
||||
utils.check_python_version(python_version)
|
||||
python_version = ".".join(map(str, python_version))
|
||||
|
||||
print("Use clang: {}".format("yes" if args.use_clang else "no"))
|
||||
@ -594,9 +385,9 @@ def main():
|
||||
clang_major_version = None
|
||||
if args.use_clang:
|
||||
if not clang_path:
|
||||
clang_path = get_clang_path_or_exit()
|
||||
clang_path = utils.get_clang_path_or_exit()
|
||||
print(f"clang path: {clang_path}")
|
||||
clang_major_version = get_clang_major_version(clang_path)
|
||||
clang_major_version = utils.get_clang_major_version(clang_path)
|
||||
|
||||
print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
|
||||
print(f"Target CPU: {wheel_cpu}")
|
||||
@ -648,7 +439,7 @@ def main():
|
||||
update_command = ([bazel_path] + args.bazel_startup_options +
|
||||
["run", "--verbose_failures=true", task, *args.bazel_options])
|
||||
print(" ".join(update_command))
|
||||
shell(update_command)
|
||||
utils.shell(update_command)
|
||||
return
|
||||
|
||||
if args.configure_only:
|
||||
@ -675,27 +466,29 @@ def main():
|
||||
|
||||
if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin:
|
||||
build_cpu_wheel_command = [
|
||||
*command_base,
|
||||
"//jaxlib/tools:build_wheel", "--",
|
||||
f"--output_path={output_path_jaxlib}",
|
||||
f"--jaxlib_git_hash={get_githash()}",
|
||||
f"--cpu={wheel_cpu}"
|
||||
*command_base,
|
||||
"//jaxlib/tools:build_wheel",
|
||||
"--",
|
||||
f"--output_path={output_path_jaxlib}",
|
||||
f"--jaxlib_git_hash={utils.get_githash()}",
|
||||
f"--cpu={wheel_cpu}",
|
||||
]
|
||||
if args.build_gpu_plugin:
|
||||
build_cpu_wheel_command.append("--skip_gpu_kernels")
|
||||
if args.editable:
|
||||
build_cpu_wheel_command.append("--editable")
|
||||
print(" ".join(build_cpu_wheel_command))
|
||||
shell(build_cpu_wheel_command)
|
||||
utils.shell(build_cpu_wheel_command)
|
||||
|
||||
if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \
|
||||
(args.build_gpu_kernel_plugin == "rocm"):
|
||||
build_gpu_kernels_command = [
|
||||
*command_base,
|
||||
"//jaxlib/tools:build_gpu_kernels_wheel", "--",
|
||||
f"--output_path={output_path_jax_kernel}",
|
||||
f"--jaxlib_git_hash={get_githash()}",
|
||||
f"--cpu={wheel_cpu}",
|
||||
*command_base,
|
||||
"//jaxlib/tools:build_gpu_kernels_wheel",
|
||||
"--",
|
||||
f"--output_path={output_path_jax_kernel}",
|
||||
f"--jaxlib_git_hash={utils.get_githash()}",
|
||||
f"--cpu={wheel_cpu}",
|
||||
]
|
||||
if args.enable_cuda:
|
||||
build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}")
|
||||
@ -708,15 +501,16 @@ def main():
|
||||
if args.editable:
|
||||
build_gpu_kernels_command.append("--editable")
|
||||
print(" ".join(build_gpu_kernels_command))
|
||||
shell(build_gpu_kernels_command)
|
||||
utils.shell(build_gpu_kernels_command)
|
||||
|
||||
if args.build_gpu_plugin or args.build_gpu_pjrt_plugin:
|
||||
build_pjrt_plugin_command = [
|
||||
*command_base,
|
||||
"//jaxlib/tools:build_gpu_plugin_wheel", "--",
|
||||
f"--output_path={output_path_jax_pjrt}",
|
||||
f"--jaxlib_git_hash={get_githash()}",
|
||||
f"--cpu={wheel_cpu}",
|
||||
*command_base,
|
||||
"//jaxlib/tools:build_gpu_plugin_wheel",
|
||||
"--",
|
||||
f"--output_path={output_path_jax_pjrt}",
|
||||
f"--jaxlib_git_hash={utils.get_githash()}",
|
||||
f"--cpu={wheel_cpu}",
|
||||
]
|
||||
if args.enable_cuda:
|
||||
build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}")
|
||||
@ -729,9 +523,9 @@ def main():
|
||||
if args.editable:
|
||||
build_pjrt_plugin_command.append("--editable")
|
||||
print(" ".join(build_pjrt_plugin_command))
|
||||
shell(build_pjrt_plugin_command)
|
||||
utils.shell(build_pjrt_plugin_command)
|
||||
|
||||
shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
|
||||
utils.shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
249
build/tools/utils.py
Normal file
249
build/tools/utils.py
Normal file
@ -0,0 +1,249 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Helper script for tools/utilities used by the JAX build CLI.
|
||||
import collections
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
import urllib.request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def is_windows():
|
||||
return sys.platform.startswith("win32")
|
||||
|
||||
def shell(cmd):
|
||||
try:
|
||||
logger.info("shell(): %s", cmd)
|
||||
output = subprocess.check_output(cmd)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.info("subprocess raised: %s", e)
|
||||
if e.output:
|
||||
print(e.output)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.info("subprocess raised: %s", e)
|
||||
raise
|
||||
return output.decode("UTF-8").strip()
|
||||
|
||||
|
||||
# Bazel
|
||||
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"
|
||||
BazelPackage = collections.namedtuple(
|
||||
"BazelPackage", ["base_uri", "file", "sha256"]
|
||||
)
|
||||
bazel_packages = {
|
||||
("Linux", "x86_64"): BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-linux-x86_64",
|
||||
sha256=(
|
||||
"a40ac69263440761199fcb8da47ad4e3f328cbe79ffbf4ecc14e5ba252857307"
|
||||
),
|
||||
),
|
||||
("Linux", "aarch64"): BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-linux-arm64",
|
||||
sha256=(
|
||||
"5afe973cadc036496cac66f1414ca9be36881423f576db363d83afc9084c0c2f"
|
||||
),
|
||||
),
|
||||
("Darwin", "x86_64"): BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-darwin-x86_64",
|
||||
sha256=(
|
||||
"bbf9c2c03bac48e0514f46db0295027935535d91f6d8dcd960c53393559eab29"
|
||||
),
|
||||
),
|
||||
("Darwin", "arm64"): BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-darwin-arm64",
|
||||
sha256=(
|
||||
"c6b6dc17efcdf13fba484c6fe0b6c3361b888ae7b9573bc25a2dbe8c502448eb"
|
||||
),
|
||||
),
|
||||
("Windows", "AMD64"): BazelPackage(
|
||||
base_uri=None,
|
||||
file="bazel-6.5.0-windows-x86_64.exe",
|
||||
sha256=(
|
||||
"6eae8e7f28e1b68b833503d1a58caf139c11e52de19df0d787d974653a0ea4c6"
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def download_and_verify_bazel():
|
||||
"""Downloads a bazel binary from GitHub, verifying its SHA256 hash."""
|
||||
package = bazel_packages.get((platform.system(), platform.machine()))
|
||||
if package is None:
|
||||
return None
|
||||
|
||||
if not os.access(package.file, os.X_OK):
|
||||
uri = (package.base_uri or BAZEL_BASE_URI) + package.file
|
||||
sys.stdout.write(f"Downloading bazel from: {uri}\n")
|
||||
|
||||
def progress(block_count, block_size, total_size):
|
||||
if total_size <= 0:
|
||||
total_size = 170**6
|
||||
progress = (block_count * block_size) / total_size
|
||||
num_chars = 40
|
||||
progress_chars = int(num_chars * progress)
|
||||
sys.stdout.write(
|
||||
"{} [{}{}] {}%\r".format(
|
||||
package.file,
|
||||
"#" * progress_chars,
|
||||
"." * (num_chars - progress_chars),
|
||||
int(progress * 100.0),
|
||||
)
|
||||
)
|
||||
|
||||
tmp_path, _ = urllib.request.urlretrieve(
|
||||
uri, None, progress if sys.stdout.isatty() else None
|
||||
)
|
||||
sys.stdout.write("\n")
|
||||
|
||||
# Verify that the downloaded Bazel binary has the expected SHA256.
|
||||
with open(tmp_path, "rb") as downloaded_file:
|
||||
contents = downloaded_file.read()
|
||||
|
||||
digest = hashlib.sha256(contents).hexdigest()
|
||||
if digest != package.sha256:
|
||||
print(
|
||||
"Checksum mismatch for downloaded bazel binary (expected {}; got {})."
|
||||
.format(package.sha256, digest)
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
# Write the file as the bazel file name.
|
||||
with open(package.file, "wb") as out_file:
|
||||
out_file.write(contents)
|
||||
|
||||
# Mark the file as executable.
|
||||
st = os.stat(package.file)
|
||||
os.chmod(
|
||||
package.file, st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
|
||||
)
|
||||
|
||||
return os.path.join(".", package.file)
|
||||
|
||||
|
||||
def get_bazel_paths(bazel_path_flag):
|
||||
"""Yields a sequence of guesses about bazel path.
|
||||
|
||||
Some of sequence elements can be None. The resulting iterator is lazy and
|
||||
potentially has a side effects.
|
||||
"""
|
||||
yield bazel_path_flag
|
||||
yield shutil.which("bazel")
|
||||
yield download_and_verify_bazel()
|
||||
|
||||
|
||||
def get_bazel_path(bazel_path_flag):
|
||||
"""Returns the path to a Bazel binary, downloading Bazel if not found.
|
||||
|
||||
Also, checks Bazel's version is at least newer than 6.5.0
|
||||
|
||||
A manual version check is needed only for really old bazel versions.
|
||||
Newer bazel releases perform their own version check against .bazelversion
|
||||
(see for details
|
||||
https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).
|
||||
"""
|
||||
for path in filter(None, get_bazel_paths(bazel_path_flag)):
|
||||
version = get_bazel_version(path)
|
||||
if version is not None and version >= (6, 5, 0):
|
||||
return path, ".".join(map(str, version))
|
||||
|
||||
print(
|
||||
"Cannot find or download a suitable version of bazel."
|
||||
"Please install bazel >= 6.5.0."
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def get_bazel_version(bazel_path):
|
||||
try:
|
||||
version_output = shell([bazel_path, "--version"])
|
||||
except (subprocess.CalledProcessError, OSError):
|
||||
return None
|
||||
match = re.search(r"bazel *([0-9\\.]+)", version_output)
|
||||
if match is None:
|
||||
return None
|
||||
return tuple(int(x) for x in match.group(1).split("."))
|
||||
|
||||
|
||||
def get_clang_path_or_exit():
|
||||
which_clang_output = shutil.which("clang")
|
||||
if which_clang_output:
|
||||
# If we've found a clang on the path, need to get the fully resolved path
|
||||
# to ensure that system headers are found.
|
||||
return str(pathlib.Path(which_clang_output).resolve())
|
||||
else:
|
||||
print(
|
||||
"--clang_path is unset and clang cannot be found"
|
||||
" on the PATH. Please pass --clang_path directly."
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def get_clang_major_version(clang_path):
|
||||
clang_version_proc = subprocess.run(
|
||||
[clang_path, "-E", "-P", "-"],
|
||||
input="__clang_major__",
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
major_version = int(clang_version_proc.stdout)
|
||||
|
||||
return major_version
|
||||
|
||||
|
||||
# Python
|
||||
def get_python_bin_path(python_bin_path_flag):
|
||||
"""Returns the path to the Python interpreter to use."""
|
||||
path = python_bin_path_flag or sys.executable
|
||||
return path.replace(os.sep, "/")
|
||||
|
||||
|
||||
def get_python_version(python_bin_path):
|
||||
version_output = shell([
|
||||
python_bin_path,
|
||||
"-c",
|
||||
(
|
||||
'import sys; print("{}.{}".format(sys.version_info[0], '
|
||||
"sys.version_info[1]))"
|
||||
),
|
||||
])
|
||||
major, minor = map(int, version_output.split("."))
|
||||
return major, minor
|
||||
|
||||
def check_python_version(python_version):
|
||||
if python_version < (3, 10):
|
||||
print("ERROR: JAX requires Python 3.10 or newer, found ", python_version)
|
||||
sys.exit(-1)
|
||||
|
||||
def get_githash():
|
||||
try:
|
||||
return subprocess.run(
|
||||
["git", "rev-parse", "HEAD"], encoding="utf-8", capture_output=True
|
||||
).stdout.strip()
|
||||
except OSError:
|
||||
return ""
|
Loading…
x
Reference in New Issue
Block a user