mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead. To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`. PiperOrigin-RevId: 548133811
This commit is contained in:
parent
02e43e3510
commit
f540ae4338
@ -537,7 +537,7 @@ def main():
|
||||
|
||||
command = ([bazel_path] + args.bazel_startup_options +
|
||||
["run", "--verbose_failures=true"] +
|
||||
[":build_wheel", "--",
|
||||
["//jaxlib/tools:build_wheel", "--",
|
||||
f"--output_path={output_path}",
|
||||
f"--cpu={wheel_cpu}"])
|
||||
if args.editable:
|
||||
|
@ -101,6 +101,6 @@ ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH
|
||||
|
||||
RUN pyenv install $PYTHON_VERSION
|
||||
|
||||
RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip==22.0 && pip install numpy==1.21.0 setuptools wheel six auditwheel scipy pytest pytest-rerunfailures matplotlib absl-py
|
||||
|
||||
RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip==22.0 && pip install numpy==1.21.0 setuptools build wheel six auditwheel scipy pytest pytest-rerunfailures matplotlib absl-py
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
absl-py
|
||||
build
|
||||
cloudpickle
|
||||
colorama>=0.4.4
|
||||
numpy>=1.22
|
||||
|
@ -43,12 +43,12 @@ To build `jaxlib` from source, you must also install some prerequisites:
|
||||
|
||||
See below for Windows build instructions.
|
||||
|
||||
- Python packages: `numpy`, `wheel`.
|
||||
- Python packages: `numpy`, `wheel`, `build`.
|
||||
|
||||
You can install the necessary Python dependencies using `pip`:
|
||||
|
||||
```
|
||||
pip install numpy wheel
|
||||
pip install numpy wheel build
|
||||
```
|
||||
|
||||
To build `jaxlib` without CUDA GPU or TPU support (CPU only), you can run:
|
||||
|
@ -153,7 +153,7 @@ def prepare_wheel(sources_path):
|
||||
copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)
|
||||
|
||||
verify_mac_libraries_dont_reference_chkstack()
|
||||
copy_file("__main__/build/LICENSE.txt", dst_dir=sources_path)
|
||||
copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path)
|
||||
copy_file("__main__/jaxlib/README.md", dst_dir=sources_path)
|
||||
copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path)
|
||||
copy_file("__main__/jaxlib/setup.cfg", dst_dir=sources_path)
|
||||
@ -273,16 +273,15 @@ def build_wheel(sources_path, output_path, cpu):
|
||||
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
|
||||
("Windows", "AMD64"): ("win", "amd64"),
|
||||
}[(platform.system(), cpu)]
|
||||
python_tag_arg = (f"--python-tag=cp{sys.version_info.major}"
|
||||
python_tag_arg = (f"-C=--build-option=--python-tag=cp{sys.version_info.major}"
|
||||
f"{sys.version_info.minor}")
|
||||
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
|
||||
cwd = os.getcwd()
|
||||
platform_tag_arg = f"-C=--build-option=--plat-name={platform_name}_{cpu_name}"
|
||||
if os.environ.get('JAXLIB_NIGHTLY'):
|
||||
edit_jaxlib_version(sources_path)
|
||||
os.chdir(sources_path)
|
||||
subprocess.run([sys.executable, "setup.py", "bdist_wheel",
|
||||
python_tag_arg, platform_tag_arg], check=True)
|
||||
os.chdir(cwd)
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "build", "-n", "-w",
|
||||
python_tag_arg, platform_tag_arg],
|
||||
check=True, cwd=sources_path)
|
||||
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
|
||||
output_file = os.path.join(output_path, os.path.basename(wheel))
|
||||
sys.stderr.write(f"Output wheel: {output_file}\n\n")
|
Loading…
x
Reference in New Issue
Block a user