mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Check build
and wheel
are installed before building jaxlib
.
This commit is contained in:
parent
1054fe5a3b
commit
3c4527b6b0
@ -52,7 +52,7 @@ def shell(cmd):
|
||||
try:
|
||||
output = subprocess.check_output(cmd)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e.output)
|
||||
if e.output: print(e.output)
|
||||
raise
|
||||
return output.decode("UTF-8").strip()
|
||||
|
||||
@ -78,6 +78,12 @@ def check_python_version(python_version):
|
||||
print("ERROR: JAX requires Python 3.9 or newer, found ", python_version)
|
||||
sys.exit(-1)
|
||||
|
||||
def check_package_is_installed(python_bin_path, package):
|
||||
try:
|
||||
shell([python_bin_path, "-c", f"import {package}"])
|
||||
except:
|
||||
print(f"ERROR: jaxlib build requires package '{package}' to be installed.")
|
||||
sys.exit(-1)
|
||||
|
||||
def check_numpy_version(python_bin_path):
|
||||
version = shell(
|
||||
@ -478,6 +484,8 @@ def main():
|
||||
|
||||
numpy_version = check_numpy_version(python_bin_path)
|
||||
print(f"NumPy version: {numpy_version}")
|
||||
check_package_is_installed(python_bin_path, "wheel")
|
||||
check_package_is_installed(python_bin_path, "build")
|
||||
|
||||
print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
|
||||
print(f"Target CPU: {wheel_cpu}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user