Check build and wheel are installed before building jaxlib.

This commit is contained in:
Peter Hawkins 2023-07-26 11:45:16 -07:00
parent 1054fe5a3b
commit 3c4527b6b0

View File

@ -52,7 +52,7 @@ def shell(cmd):
try:
output = subprocess.check_output(cmd)
except subprocess.CalledProcessError as e:
print(e.output)
if e.output: print(e.output)
raise
return output.decode("UTF-8").strip()
@ -78,6 +78,12 @@ def check_python_version(python_version):
print("ERROR: JAX requires Python 3.9 or newer, found ", python_version)
sys.exit(-1)
def check_package_is_installed(python_bin_path, package):
try:
shell([python_bin_path, "-c", f"import {package}"])
except:
print(f"ERROR: jaxlib build requires package '{package}' to be installed.")
sys.exit(-1)
def check_numpy_version(python_bin_path):
version = shell(
@ -478,6 +484,8 @@ def main():
numpy_version = check_numpy_version(python_bin_path)
print(f"NumPy version: {numpy_version}")
check_package_is_installed(python_bin_path, "wheel")
check_package_is_installed(python_bin_path, "build")
print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
print(f"Target CPU: {wheel_cpu}")