Change jaxlib build rules to build a wheel, rather than writing output to the source directory.

This commit is contained in:
Peter Hawkins 2020-11-20 09:10:02 -05:00
parent 5a41779fbe
commit c06ead6b04
11 changed files with 225 additions and 164 deletions

6
.gitignore vendored
View File

@ -1,15 +1,17 @@
*.pyc
*.so
*.egg-info
*.whl
build/bazel*
dist/
.ipynb_checkpoints
/bazel-*
.bazelrc
/tensorflow
.DS_Store
build/
dist/
.mypy_cache/
.pytype/
docs/build
docs/notebooks/.ipynb_checkpoints/
docs/_autosummary
.idea

View File

@ -22,8 +22,8 @@ licenses(["notice"]) # Apache 2
package(default_visibility = ["//visibility:public"])
py_binary(
name = "install_xla_in_source_tree",
srcs = ["install_xla_in_source_tree.py"],
name = "build_wheel",
srcs = ["build_wheel.py"],
data = [
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
"//jaxlib",

View File

@ -14,10 +14,10 @@ RUN /pyenv/bin/pyenv install 3.8.0
RUN /pyenv/bin/pyenv install 3.9.0
# We pin numpy to a version < 1.16 to avoid version compatibility issues.
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.17.3 scipy cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 scipy==1.5.4 cython setuptools wheel packaging six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.17.3 scipy==1.5.4 setuptools wheel six auditwheel
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 scipy==1.5.4 setuptools wheel six auditwheel
# Change the CUDA version if it doesn't match the installed version.
ARG JAX_CUDA_VERSION=10.0

View File

@ -333,8 +333,9 @@ def add_boolean_argument(parser, name, default=False, help_str=None):
def main():
cwd = os.getcwd()
parser = argparse.ArgumentParser(
description="Builds libjax from source.", epilog=EPILOG)
description="Builds jaxlib from source.", epilog=EPILOG)
parser.add_argument(
"--bazel_path",
help="Path to the Bazel binary to use. The default is to find bazel via "
@ -388,6 +389,10 @@ def main():
"--bazel_options",
action="append", default=[],
help="Additional options to pass to bazel.")
parser.add_argument(
"--output_path",
default=os.path.join(cwd, "dist"),
help="Directory to which the jaxlib wheel should be written")
args = parser.parse_args()
if is_windows() and args.enable_cuda:
@ -397,6 +402,8 @@ def main():
parser.error("--cudnn_version is needed for Windows CUDA build.")
print(BANNER)
output_path = os.path.abspath(args.output_path)
os.chdir(os.path.dirname(__file__ or args.prog) or '.')
# Find a working Bazel.
@ -447,7 +454,8 @@ def main():
config_args += ["--define=xla_python_enable_gpu=true"]
command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] + config_args +
[":install_xla_in_source_tree", os.getcwd()])
[":build_wheel", "--",
f"--output_path={output_path}"])
print(" ".join(command))
shell(command)
shell([bazel_path, "shutdown"])

View File

@ -5,7 +5,7 @@ set -e
# Builds wheels for multiple Python versions, using pyenv instead of Docker.
# Usage: run from root of JAX source tree as:
# build/build_jaxlib_wheels_macos.sh
# The wheels will end up in build/dist.
# The wheels will end up in dist/
#
# Requires pyenv, pyenv-virtualenv (e.g., from Homebrew). If you have Homebrew
# installed, you can install these with:
@ -20,14 +20,11 @@ if ! pyenv --version 2>/dev/null ;then
fi
eval "$(pyenv init -)"
PLATFORM_TAG="macosx_10_9_x86_64"
build_jax () {
PY_VERSION="$1"
PY_TAG="$2"
NUMPY_VERSION="$3"
SCIPY_VERSION="$4"
echo -e "\nBuilding JAX for Python ${PY_VERSION}, tag ${PY_TAG}"
NUMPY_VERSION="$2"
SCIPY_VERSION="$3"
echo -e "\nBuilding JAX for Python ${PY_VERSION}"
echo "NumPy version ${NUMPY_VERSION}, SciPy version ${SCIPY_VERSION}"
pyenv install -s "${PY_VERSION}"
VENV="jax-build-${PY_VERSION}"
@ -41,17 +38,14 @@ build_jax () {
# earlier Numpy versions.
pip install numpy==$NUMPY_VERSION scipy==$SCIPY_VERSION wheel future six
rm -fr build/build
python build/build.py
cd build
python setup.py bdist_wheel --python-tag "${PY_TAG}" --plat-name "${PLATFORM_TAG}"
cd ..
python build/build.py --output_path=dist/
pyenv deactivate
pyenv virtualenv-delete -f "${VENV}"
}
rm -fr build/dist
build_jax 3.6.8 cp36 1.15.4 1.2.0
build_jax 3.7.2 cp37 1.15.4 1.2.0
build_jax 3.8.0 cp38 1.17.3 1.3.2
build_jax 3.9.0 cp39 1.19.4 1.5.4
rm -fr dist
build_jax 3.6.8 1.15.4 1.2.0
build_jax 3.7.2 1.15.4 1.2.0
build_jax 3.8.0 1.17.3 1.3.2
build_jax 3.9.0 1.19.4 1.5.4

172
build/build_wheel.py Normal file
View File

@ -0,0 +1,172 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Script that builds a jaxlib wheel, intended to be run via bazel run as part
# of the jaxlib build process.
# Most users should not run this script directly; use build.py instead.
import argparse
import functools
import glob
import os
import platform
import shutil
import subprocess
import sys
import tempfile
from bazel_tools.tools.python.runfiles import runfiles
parser = argparse.ArgumentParser()
parser.add_argument(
"--sources_path",
default=None,
help="Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used.")
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.")
args = parser.parse_args()
r = runfiles.Create()
def _is_windows():
return sys.platform.startswith("win32")
def _copy_so(src_file, dst_dir, dst_filename=None):
src_filename = os.path.basename(src_file)
if not dst_filename:
if _is_windows() and src_filename.endswith(".so"):
dst_filename = src_filename[:-3] + ".pyd"
else:
dst_filename = src_filename
dst_file = os.path.join(dst_dir, dst_filename)
shutil.copy(src_file, dst_file)
def _copy_normal(src_file, dst_dir, dst_filename=None):
src_filename = os.path.basename(src_file)
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
shutil.copy(src_file, dst_file)
def copy_file(src_file, dst_dir, dst_filename=None):
if src_file.endswith(".so"):
_copy_so(src_file, dst_dir, dst_filename=dst_filename)
else:
_copy_normal(src_file, dst_dir, dst_filename=dst_filename)
def patch_copy_xla_client_py(dst_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
src = f.read()
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
with open(os.path.join(dst_dir, "xla_client.py"), "w") as f:
f.write(src)
def patch_copy_tpu_client_py(dst_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
src = f.read()
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
src = src.replace("from tensorflow.compiler.xla.python import xla_client",
"from . import xla_client")
src = src.replace(
"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
"from . import tpu_client_extension as _tpu_client")
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
f.write(src)
def prepare_wheel(sources_path):
"""Assembles a source tree for the wheel in `sources_path`."""
jaxlib_dir = os.path.join(sources_path, "jaxlib")
os.makedirs(jaxlib_dir)
copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)
copy_file(r.Rlocation("__main__/jaxlib/setup.py"), dst_dir=sources_path)
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/init.py"), dst_filename="__init__.py")
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/lapack.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_pocketfft.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft.py"))
if r.Rlocation("__main__/jaxlib/cusolver_kernels.so") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cublas_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.so"))
if r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd") is not None:
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cublas_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.pyd"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver.py"))
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng.py"))
if _is_windows():
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
else:
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
patch_copy_xla_client_py(jaxlib_dir)
if not _is_windows():
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"))
patch_copy_tpu_client_py(jaxlib_dir)
def build_wheel(sources_path, output_path):
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
platform_name = {
"Linux": "manylinux2010",
"Darwin": "macosx_10_9",
"Windows": "win",
}[platform.system()]
cpu_name = "amd64" if platform.system() == "Windows" else "x86_64"
python_tag_arg = (f"--python-tag=cp{sys.version_info.major}"
f"{sys.version_info.minor}")
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
cwd = os.getcwd()
os.chdir(sources_path)
subprocess.run([sys.executable, "setup.py", "bdist_wheel",
python_tag_arg, platform_tag_arg])
os.chdir(cwd)
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
output_file = os.path.join(output_path, os.path.basename(wheel))
sys.stderr.write(f"Output wheel: {output_file}\n\n")
sys.stderr.write(f"To install the newly-built jaxlib wheel, run:\n")
sys.stderr.write(f" pip install {output_file}\n\n")
shutil.copy(wheel, output_path)
tmpdir = None
sources_path = args.sources_path
if sources_path is None:
tmpdir = tempfile.TemporaryDirectory(prefix="jaxlib")
sources_path = tmpdir.name
try:
os.makedirs(args.output_path, exist_ok=True)
prepare_wheel(sources_path)
build_wheel(sources_path, args.output_path)
finally:
if tmpdir:
tmpdir.cleanup()

View File

@ -35,33 +35,25 @@ fi
# Builds and activates a specific Python version.
pyenv local "$PY_VERSION"
PY_TAG=$(python -c "import packaging.tags as t; print(t.interpreter_name() + t.interpreter_version())")
echo "Python tag: $PY_TAG"
# Workaround for https://github.com/bazelbuild/bazel/issues/9254
export BAZEL_LINKLIBS="-lstdc++"
export JAX_CUDA_VERSION=$3
case $2 in
cuda-included)
python build.py --enable_cuda --bazel_startup_options="--output_user_root=/build/root"
python include_cuda.py
PLAT_NAME="manylinux2010_x86_64"
;;
cuda)
python build.py --enable_cuda --bazel_startup_options="--output_user_root=/build/root"
PLAT_NAME="manylinux2010_x86_64"
;;
nocuda)
python build.py --bazel_startup_options="--output_user_root=/build/root"
PLAT_NAME="manylinux2010_x86_64"
;;
*)
usage
esac
export JAX_CUDA_VERSION=$3
python setup.py bdist_wheel --python-tag "$PY_TAG" --plat-name "$PLAT_NAME"
if ! python -m auditwheel show dist/jaxlib-*.whl | grep 'platform tag: "manylinux2010_x86_64"' > /dev/null; then
# Print output for debugging
python -m auditwheel show dist/jaxlib-*.whl

View File

@ -1,113 +0,0 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import shutil
import argparse
from bazel_tools.tools.python.runfiles import runfiles
parser = argparse.ArgumentParser()
parser.add_argument("target")
args = parser.parse_args()
r = runfiles.Create()
jaxlib_dir = os.path.join(args.target, "jaxlib")
def _is_windows():
return sys.platform.startswith("win32")
def _copy_so(src_file, dst_dir, dst_filename=None):
src_filename = os.path.basename(src_file)
if not dst_filename:
if _is_windows() and src_filename.endswith(".so"):
dst_filename = src_filename[:-3] + ".pyd"
else:
dst_filename = src_filename
dst_file = os.path.join(dst_dir, dst_filename)
shutil.copy(src_file, dst_file)
def _copy_normal(src_file, dst_dir, dst_filename=None):
src_filename = os.path.basename(src_file)
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
shutil.copy(src_file, dst_file)
def copy(src_file, dst_dir=jaxlib_dir, dst_filename=None):
if src_file.endswith(".so"):
_copy_so(src_file, dst_dir, dst_filename=dst_filename)
else:
_copy_normal(src_file, dst_dir, dst_filename=dst_filename)
def patch_copy_xla_client_py(dst_dir=jaxlib_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
src = f.read()
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
src = src.replace("from tensorflow.compiler.xla.python.xla_extension import ops",
"from .xla_extension import ops")
with open(os.path.join(dst_dir, "xla_client.py"), "w") as f:
f.write(src)
def patch_copy_tpu_client_py(dst_dir=jaxlib_dir):
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
src = f.read()
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
"from . import xla_extension as _xla")
src = src.replace("from tensorflow.compiler.xla.python import xla_client",
"from . import xla_client")
src = src.replace(
"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
"from . import tpu_client_extension as _tpu_client")
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
f.write(src)
shutil.rmtree(jaxlib_dir)
os.makedirs(jaxlib_dir)
copy(r.Rlocation("__main__/jaxlib/init.py"), dst_filename="__init__.py")
copy(r.Rlocation("__main__/jaxlib/lapack.so"))
copy(r.Rlocation("__main__/jaxlib/_pocketfft.so"))
copy(r.Rlocation("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py"))
copy(r.Rlocation("__main__/jaxlib/pocketfft.py"))
if r.Rlocation("__main__/jaxlib/cusolver_kernels.so") is not None:
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
copy(r.Rlocation("__main__/jaxlib/cublas_kernels.so"))
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
copy(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.so"))
if r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd") is not None:
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
copy(r.Rlocation("__main__/jaxlib/cublas_kernels.pyd"))
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
copy(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.pyd"))
copy(r.Rlocation("__main__/jaxlib/version.py"))
copy(r.Rlocation("__main__/jaxlib/cusolver.py"))
copy(r.Rlocation("__main__/jaxlib/cuda_prng.py"))
if _is_windows():
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
else:
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
patch_copy_xla_client_py()
if not _is_windows():
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"))
patch_copy_tpu_client_py()

View File

@ -31,46 +31,50 @@ Building ``jaxlib`` from source
To build ``jaxlib`` from source, you must also install some prerequisites:
* a C++ compiler (g++, clang, or MSVC)
* Numpy
* Scipy
* Cython
* six (required for during the jaxlib build only, not required at install time)
On Ubuntu 18.04 or Debian you can install the necessary prerequisites with::
On Ubuntu or Debian you can install the necessary prerequisites with::
sudo apt-get install g++ python python3-dev python3-numpy python3-scipy cython3 python3-six
sudo apt install g++ python python3-dev
If you are building on a Mac, make sure XCode and the XCode command line tools
are installed.
See below for Windows build instructions.
* Python packages: ``numpy``, ``scipy``, ``six``, ``wheel``.
The ``six`` package is required for during the jaxlib build only, and is not
required at install time.
If you are building on a Mac, make sure XCode and the XCode command line tools
are installed.
You can install the necessary Python dependencies using ``pip``::
You can also install the necessary Python dependencies using ``pip``::
pip install numpy scipy cython six
pip install numpy scipy six wheel
To build ``jaxlib`` with CUDA support, you can run::
python build/build.py --enable_cuda
pip install -e build # installs jaxlib (includes XLA)
pip install -e dist/*.whl # installs jaxlib (includes XLA)
See ``python build/build.py --help`` for configuration options, including ways to
specify the paths to CUDA and CUDNN, which you must have installed. Here
``python`` should be the name of your Python 3 interpreter; on some systems, you
may need to use ``python3`` instead.
may need to use ``python3`` instead. By default, the wheel is written to the
``dist/`` subdirectory of the current directory.
To build ``jaxlib`` without CUDA GPU support (CPU only), drop the ``--enable_cuda``::
python build/build.py
pip install -e build # installs jaxlib (includes XLA)
pip install dist/*.whl # installs jaxlib (includes XLA)
Additional Notes for Building ``jaxlib`` from source on Windows
...............................................................
On Windows, follow `Install Visual Studio <https://docs.microsoft.com/en-us/visualstudio/install/install-visual-studio?view=vs-2019>`_
to setup latest C++ toolchain. Visual Studio 2019 version 16.5 or newer is required.
to setup a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required.
If you need to build with CUDA enabled, follow
`CUDA Installation Guide <https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html>`_
to setup CUDA environment.
@ -89,7 +93,7 @@ for more details. Install the following packages::
Once everything is installed. Open PowerShell, and make sure MSYS2 is in the
path of the current session. Ensure ``bazel``, ``patch`` and ``realpath`` are
accessible. Activate the conda environment. The following command builds with
cuda enabled, adjust it to whatever suitable for you::
CUDA enabled, adjust it to whatever suitable for you::
python .\build\build.py `
--enable_cuda `

View File

@ -80,6 +80,7 @@ py_library(
"cusolver.py",
"init.py",
"pocketfft.py",
"setup.py",
"version.py",
],
deps = [":pocketfft_flatbuffers_py"],

View File

@ -26,6 +26,7 @@ if cuda_version:
__version__ += "+cuda" + cuda_version.replace(".", "")
binary_libs = [os.path.basename(f) for f in glob('jaxlib/*.so*')]
binary_libs += [os.path.basename(f) for f in glob('jaxlib/*.pyd*')]
setup(
name='jaxlib',