Fix warning about direct invocation of setup.py during jaxlib build.

The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
This commit is contained in:
Peter Hawkins 2023-07-14 08:30:41 -07:00 committed by jax authors
parent 02e43e3510
commit f540ae4338
6 changed files with 13 additions and 13 deletions

View File

@ -537,7 +537,7 @@ def main():
command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] +
[":build_wheel", "--",
["//jaxlib/tools:build_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}"])
if args.editable:

View File

@ -101,6 +101,6 @@ ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH
RUN pyenv install $PYTHON_VERSION
RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip==22.0 && pip install numpy==1.21.0 setuptools wheel six auditwheel scipy pytest pytest-rerunfailures matplotlib absl-py
RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip==22.0 && pip install numpy==1.21.0 setuptools build wheel six auditwheel scipy pytest pytest-rerunfailures matplotlib absl-py

View File

@ -1,4 +1,5 @@
absl-py
build
cloudpickle
colorama>=0.4.4
numpy>=1.22

View File

@ -43,12 +43,12 @@ To build `jaxlib` from source, you must also install some prerequisites:
See below for Windows build instructions.
- Python packages: `numpy`, `wheel`.
- Python packages: `numpy`, `wheel`, `build`.
You can install the necessary Python dependencies using `pip`:
```
pip install numpy wheel
pip install numpy wheel build
```
To build `jaxlib` without CUDA GPU or TPU support (CPU only), you can run:

View File

@ -153,7 +153,7 @@ def prepare_wheel(sources_path):
copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)
verify_mac_libraries_dont_reference_chkstack()
copy_file("__main__/build/LICENSE.txt", dst_dir=sources_path)
copy_file("__main__/jaxlib/tools/LICENSE.txt", dst_dir=sources_path)
copy_file("__main__/jaxlib/README.md", dst_dir=sources_path)
copy_file("__main__/jaxlib/setup.py", dst_dir=sources_path)
copy_file("__main__/jaxlib/setup.cfg", dst_dir=sources_path)
@ -273,16 +273,15 @@ def build_wheel(sources_path, output_path, cpu):
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
("Windows", "AMD64"): ("win", "amd64"),
}[(platform.system(), cpu)]
python_tag_arg = (f"--python-tag=cp{sys.version_info.major}"
python_tag_arg = (f"-C=--build-option=--python-tag=cp{sys.version_info.major}"
f"{sys.version_info.minor}")
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
cwd = os.getcwd()
platform_tag_arg = f"-C=--build-option=--plat-name={platform_name}_{cpu_name}"
if os.environ.get('JAXLIB_NIGHTLY'):
edit_jaxlib_version(sources_path)
os.chdir(sources_path)
subprocess.run([sys.executable, "setup.py", "bdist_wheel",
python_tag_arg, platform_tag_arg], check=True)
os.chdir(cwd)
subprocess.run(
[sys.executable, "-m", "build", "-n", "-w",
python_tag_arg, platform_tag_arg],
check=True, cwd=sources_path)
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
output_file = os.path.join(output_path, os.path.basename(wheel))
sys.stderr.write(f"Output wheel: {output_file}\n\n")