mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Change jaxlib build rules to build a wheel, rather than writing output to the source directory.
This commit is contained in:
parent
5a41779fbe
commit
c06ead6b04
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,15 +1,17 @@
|
||||
*.pyc
|
||||
*.so
|
||||
*.egg-info
|
||||
*.whl
|
||||
build/bazel*
|
||||
dist/
|
||||
.ipynb_checkpoints
|
||||
/bazel-*
|
||||
.bazelrc
|
||||
/tensorflow
|
||||
.DS_Store
|
||||
build/
|
||||
dist/
|
||||
.mypy_cache/
|
||||
.pytype/
|
||||
docs/build
|
||||
docs/notebooks/.ipynb_checkpoints/
|
||||
docs/_autosummary
|
||||
.idea
|
||||
|
@ -22,8 +22,8 @@ licenses(["notice"]) # Apache 2
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
py_binary(
|
||||
name = "install_xla_in_source_tree",
|
||||
srcs = ["install_xla_in_source_tree.py"],
|
||||
name = "build_wheel",
|
||||
srcs = ["build_wheel.py"],
|
||||
data = [
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
|
||||
"//jaxlib",
|
||||
|
@ -14,10 +14,10 @@ RUN /pyenv/bin/pyenv install 3.8.0
|
||||
RUN /pyenv/bin/pyenv install 3.9.0
|
||||
|
||||
# We pin numpy to a version < 1.16 to avoid version compatibility issues.
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy cython setuptools wheel packaging six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy cython setuptools wheel packaging six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.17.3 scipy cython setuptools wheel packaging six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 scipy==1.5.4 cython setuptools wheel packaging six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.6.8 && pip install numpy==1.15.4 scipy==1.5.4 setuptools wheel six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.7.2 && pip install numpy==1.15.4 scipy==1.5.4 setuptools wheel six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.8.0 && pip install numpy==1.17.3 scipy==1.5.4 setuptools wheel six auditwheel
|
||||
RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.9.0 && pip install numpy==1.19.4 scipy==1.5.4 setuptools wheel six auditwheel
|
||||
|
||||
# Change the CUDA version if it doesn't match the installed version.
|
||||
ARG JAX_CUDA_VERSION=10.0
|
||||
|
@ -333,8 +333,9 @@ def add_boolean_argument(parser, name, default=False, help_str=None):
|
||||
|
||||
|
||||
def main():
|
||||
cwd = os.getcwd()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Builds libjax from source.", epilog=EPILOG)
|
||||
description="Builds jaxlib from source.", epilog=EPILOG)
|
||||
parser.add_argument(
|
||||
"--bazel_path",
|
||||
help="Path to the Bazel binary to use. The default is to find bazel via "
|
||||
@ -388,6 +389,10 @@ def main():
|
||||
"--bazel_options",
|
||||
action="append", default=[],
|
||||
help="Additional options to pass to bazel.")
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
default=os.path.join(cwd, "dist"),
|
||||
help="Directory to which the jaxlib wheel should be written")
|
||||
args = parser.parse_args()
|
||||
|
||||
if is_windows() and args.enable_cuda:
|
||||
@ -397,6 +402,8 @@ def main():
|
||||
parser.error("--cudnn_version is needed for Windows CUDA build.")
|
||||
|
||||
print(BANNER)
|
||||
|
||||
output_path = os.path.abspath(args.output_path)
|
||||
os.chdir(os.path.dirname(__file__ or args.prog) or '.')
|
||||
|
||||
# Find a working Bazel.
|
||||
@ -447,7 +454,8 @@ def main():
|
||||
config_args += ["--define=xla_python_enable_gpu=true"]
|
||||
command = ([bazel_path] + args.bazel_startup_options +
|
||||
["run", "--verbose_failures=true"] + config_args +
|
||||
[":install_xla_in_source_tree", os.getcwd()])
|
||||
[":build_wheel", "--",
|
||||
f"--output_path={output_path}"])
|
||||
print(" ".join(command))
|
||||
shell(command)
|
||||
shell([bazel_path, "shutdown"])
|
||||
|
@ -5,7 +5,7 @@ set -e
|
||||
# Builds wheels for multiple Python versions, using pyenv instead of Docker.
|
||||
# Usage: run from root of JAX source tree as:
|
||||
# build/build_jaxlib_wheels_macos.sh
|
||||
# The wheels will end up in build/dist.
|
||||
# The wheels will end up in dist/
|
||||
#
|
||||
# Requires pyenv, pyenv-virtualenv (e.g., from Homebrew). If you have Homebrew
|
||||
# installed, you can install these with:
|
||||
@ -20,14 +20,11 @@ if ! pyenv --version 2>/dev/null ;then
|
||||
fi
|
||||
eval "$(pyenv init -)"
|
||||
|
||||
PLATFORM_TAG="macosx_10_9_x86_64"
|
||||
|
||||
build_jax () {
|
||||
PY_VERSION="$1"
|
||||
PY_TAG="$2"
|
||||
NUMPY_VERSION="$3"
|
||||
SCIPY_VERSION="$4"
|
||||
echo -e "\nBuilding JAX for Python ${PY_VERSION}, tag ${PY_TAG}"
|
||||
NUMPY_VERSION="$2"
|
||||
SCIPY_VERSION="$3"
|
||||
echo -e "\nBuilding JAX for Python ${PY_VERSION}"
|
||||
echo "NumPy version ${NUMPY_VERSION}, SciPy version ${SCIPY_VERSION}"
|
||||
pyenv install -s "${PY_VERSION}"
|
||||
VENV="jax-build-${PY_VERSION}"
|
||||
@ -41,17 +38,14 @@ build_jax () {
|
||||
# earlier Numpy versions.
|
||||
pip install numpy==$NUMPY_VERSION scipy==$SCIPY_VERSION wheel future six
|
||||
rm -fr build/build
|
||||
python build/build.py
|
||||
cd build
|
||||
python setup.py bdist_wheel --python-tag "${PY_TAG}" --plat-name "${PLATFORM_TAG}"
|
||||
cd ..
|
||||
python build/build.py --output_path=dist/
|
||||
pyenv deactivate
|
||||
pyenv virtualenv-delete -f "${VENV}"
|
||||
}
|
||||
|
||||
|
||||
rm -fr build/dist
|
||||
build_jax 3.6.8 cp36 1.15.4 1.2.0
|
||||
build_jax 3.7.2 cp37 1.15.4 1.2.0
|
||||
build_jax 3.8.0 cp38 1.17.3 1.3.2
|
||||
build_jax 3.9.0 cp39 1.19.4 1.5.4
|
||||
rm -fr dist
|
||||
build_jax 3.6.8 1.15.4 1.2.0
|
||||
build_jax 3.7.2 1.15.4 1.2.0
|
||||
build_jax 3.8.0 1.17.3 1.3.2
|
||||
build_jax 3.9.0 1.19.4 1.5.4
|
||||
|
172
build/build_wheel.py
Normal file
172
build/build_wheel.py
Normal file
@ -0,0 +1,172 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Script that builds a jaxlib wheel, intended to be run via bazel run as part
|
||||
# of the jaxlib build process.
|
||||
|
||||
# Most users should not run this script directly; use build.py instead.
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import glob
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--sources_path",
|
||||
default=None,
|
||||
help="Path in which the wheel's sources should be prepared. Optional. If "
|
||||
"omitted, a temporary directory will be used.")
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.")
|
||||
args = parser.parse_args()
|
||||
|
||||
r = runfiles.Create()
|
||||
|
||||
|
||||
def _is_windows():
|
||||
return sys.platform.startswith("win32")
|
||||
|
||||
|
||||
def _copy_so(src_file, dst_dir, dst_filename=None):
|
||||
src_filename = os.path.basename(src_file)
|
||||
if not dst_filename:
|
||||
if _is_windows() and src_filename.endswith(".so"):
|
||||
dst_filename = src_filename[:-3] + ".pyd"
|
||||
else:
|
||||
dst_filename = src_filename
|
||||
dst_file = os.path.join(dst_dir, dst_filename)
|
||||
shutil.copy(src_file, dst_file)
|
||||
|
||||
|
||||
def _copy_normal(src_file, dst_dir, dst_filename=None):
|
||||
src_filename = os.path.basename(src_file)
|
||||
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
|
||||
shutil.copy(src_file, dst_file)
|
||||
|
||||
|
||||
def copy_file(src_file, dst_dir, dst_filename=None):
|
||||
if src_file.endswith(".so"):
|
||||
_copy_so(src_file, dst_dir, dst_filename=dst_filename)
|
||||
else:
|
||||
_copy_normal(src_file, dst_dir, dst_filename=dst_filename)
|
||||
|
||||
def patch_copy_xla_client_py(dst_dir):
|
||||
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
|
||||
src = f.read()
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
|
||||
"from . import xla_extension as _xla")
|
||||
with open(os.path.join(dst_dir, "xla_client.py"), "w") as f:
|
||||
f.write(src)
|
||||
|
||||
|
||||
def patch_copy_tpu_client_py(dst_dir):
|
||||
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
|
||||
src = f.read()
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
|
||||
"from . import xla_extension as _xla")
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_client",
|
||||
"from . import xla_client")
|
||||
src = src.replace(
|
||||
"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
|
||||
"from . import tpu_client_extension as _tpu_client")
|
||||
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
|
||||
f.write(src)
|
||||
|
||||
|
||||
def prepare_wheel(sources_path):
|
||||
"""Assembles a source tree for the wheel in `sources_path`."""
|
||||
jaxlib_dir = os.path.join(sources_path, "jaxlib")
|
||||
os.makedirs(jaxlib_dir)
|
||||
copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)
|
||||
|
||||
copy_file(r.Rlocation("__main__/jaxlib/setup.py"), dst_dir=sources_path)
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/init.py"), dst_filename="__init__.py")
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/lapack.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/_pocketfft.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/pocketfft.py"))
|
||||
if r.Rlocation("__main__/jaxlib/cusolver_kernels.so") is not None:
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cublas_kernels.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.so"))
|
||||
if r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd") is not None:
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cublas_kernels.pyd"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.pyd"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/version.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cusolver.py"))
|
||||
copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cuda_prng.py"))
|
||||
|
||||
if _is_windows():
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
|
||||
else:
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
|
||||
patch_copy_xla_client_py(jaxlib_dir)
|
||||
|
||||
if not _is_windows():
|
||||
copy_to_jaxlib(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"))
|
||||
patch_copy_tpu_client_py(jaxlib_dir)
|
||||
|
||||
|
||||
def build_wheel(sources_path, output_path):
|
||||
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
|
||||
platform_name = {
|
||||
"Linux": "manylinux2010",
|
||||
"Darwin": "macosx_10_9",
|
||||
"Windows": "win",
|
||||
}[platform.system()]
|
||||
cpu_name = "amd64" if platform.system() == "Windows" else "x86_64"
|
||||
python_tag_arg = (f"--python-tag=cp{sys.version_info.major}"
|
||||
f"{sys.version_info.minor}")
|
||||
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
|
||||
cwd = os.getcwd()
|
||||
os.chdir(sources_path)
|
||||
subprocess.run([sys.executable, "setup.py", "bdist_wheel",
|
||||
python_tag_arg, platform_tag_arg])
|
||||
os.chdir(cwd)
|
||||
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
|
||||
output_file = os.path.join(output_path, os.path.basename(wheel))
|
||||
sys.stderr.write(f"Output wheel: {output_file}\n\n")
|
||||
sys.stderr.write(f"To install the newly-built jaxlib wheel, run:\n")
|
||||
sys.stderr.write(f" pip install {output_file}\n\n")
|
||||
shutil.copy(wheel, output_path)
|
||||
|
||||
|
||||
tmpdir = None
|
||||
sources_path = args.sources_path
|
||||
if sources_path is None:
|
||||
tmpdir = tempfile.TemporaryDirectory(prefix="jaxlib")
|
||||
sources_path = tmpdir.name
|
||||
|
||||
try:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
prepare_wheel(sources_path)
|
||||
build_wheel(sources_path, args.output_path)
|
||||
finally:
|
||||
if tmpdir:
|
||||
tmpdir.cleanup()
|
||||
|
@ -35,33 +35,25 @@ fi
|
||||
# Builds and activates a specific Python version.
|
||||
pyenv local "$PY_VERSION"
|
||||
|
||||
PY_TAG=$(python -c "import packaging.tags as t; print(t.interpreter_name() + t.interpreter_version())")
|
||||
|
||||
echo "Python tag: $PY_TAG"
|
||||
|
||||
# Workaround for https://github.com/bazelbuild/bazel/issues/9254
|
||||
export BAZEL_LINKLIBS="-lstdc++"
|
||||
|
||||
export JAX_CUDA_VERSION=$3
|
||||
case $2 in
|
||||
cuda-included)
|
||||
python build.py --enable_cuda --bazel_startup_options="--output_user_root=/build/root"
|
||||
python include_cuda.py
|
||||
PLAT_NAME="manylinux2010_x86_64"
|
||||
;;
|
||||
cuda)
|
||||
python build.py --enable_cuda --bazel_startup_options="--output_user_root=/build/root"
|
||||
PLAT_NAME="manylinux2010_x86_64"
|
||||
;;
|
||||
nocuda)
|
||||
python build.py --bazel_startup_options="--output_user_root=/build/root"
|
||||
PLAT_NAME="manylinux2010_x86_64"
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
esac
|
||||
|
||||
export JAX_CUDA_VERSION=$3
|
||||
python setup.py bdist_wheel --python-tag "$PY_TAG" --plat-name "$PLAT_NAME"
|
||||
if ! python -m auditwheel show dist/jaxlib-*.whl | grep 'platform tag: "manylinux2010_x86_64"' > /dev/null; then
|
||||
# Print output for debugging
|
||||
python -m auditwheel show dist/jaxlib-*.whl
|
||||
|
@ -1,113 +0,0 @@
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import argparse
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("target")
|
||||
args = parser.parse_args()
|
||||
|
||||
r = runfiles.Create()
|
||||
|
||||
jaxlib_dir = os.path.join(args.target, "jaxlib")
|
||||
|
||||
def _is_windows():
|
||||
return sys.platform.startswith("win32")
|
||||
|
||||
|
||||
def _copy_so(src_file, dst_dir, dst_filename=None):
|
||||
src_filename = os.path.basename(src_file)
|
||||
if not dst_filename:
|
||||
if _is_windows() and src_filename.endswith(".so"):
|
||||
dst_filename = src_filename[:-3] + ".pyd"
|
||||
else:
|
||||
dst_filename = src_filename
|
||||
dst_file = os.path.join(dst_dir, dst_filename)
|
||||
shutil.copy(src_file, dst_file)
|
||||
|
||||
|
||||
def _copy_normal(src_file, dst_dir, dst_filename=None):
|
||||
src_filename = os.path.basename(src_file)
|
||||
dst_file = os.path.join(dst_dir, dst_filename or src_filename)
|
||||
shutil.copy(src_file, dst_file)
|
||||
|
||||
|
||||
def copy(src_file, dst_dir=jaxlib_dir, dst_filename=None):
|
||||
if src_file.endswith(".so"):
|
||||
_copy_so(src_file, dst_dir, dst_filename=dst_filename)
|
||||
else:
|
||||
_copy_normal(src_file, dst_dir, dst_filename=dst_filename)
|
||||
|
||||
|
||||
def patch_copy_xla_client_py(dst_dir=jaxlib_dir):
|
||||
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_client.py")) as f:
|
||||
src = f.read()
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
|
||||
"from . import xla_extension as _xla")
|
||||
src = src.replace("from tensorflow.compiler.xla.python.xla_extension import ops",
|
||||
"from .xla_extension import ops")
|
||||
with open(os.path.join(dst_dir, "xla_client.py"), "w") as f:
|
||||
f.write(src)
|
||||
|
||||
|
||||
def patch_copy_tpu_client_py(dst_dir=jaxlib_dir):
|
||||
with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
|
||||
src = f.read()
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
|
||||
"from . import xla_extension as _xla")
|
||||
src = src.replace("from tensorflow.compiler.xla.python import xla_client",
|
||||
"from . import xla_client")
|
||||
src = src.replace(
|
||||
"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
|
||||
"from . import tpu_client_extension as _tpu_client")
|
||||
with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
|
||||
f.write(src)
|
||||
|
||||
shutil.rmtree(jaxlib_dir)
|
||||
os.makedirs(jaxlib_dir)
|
||||
|
||||
copy(r.Rlocation("__main__/jaxlib/init.py"), dst_filename="__init__.py")
|
||||
copy(r.Rlocation("__main__/jaxlib/lapack.so"))
|
||||
copy(r.Rlocation("__main__/jaxlib/_pocketfft.so"))
|
||||
copy(r.Rlocation("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py"))
|
||||
copy(r.Rlocation("__main__/jaxlib/pocketfft.py"))
|
||||
if r.Rlocation("__main__/jaxlib/cusolver_kernels.so") is not None:
|
||||
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
|
||||
copy(r.Rlocation("__main__/jaxlib/cublas_kernels.so"))
|
||||
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.so"))
|
||||
copy(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.so"))
|
||||
if r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd") is not None:
|
||||
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
|
||||
copy(r.Rlocation("__main__/jaxlib/cublas_kernels.pyd"))
|
||||
copy(r.Rlocation("__main__/jaxlib/cusolver_kernels.pyd"))
|
||||
copy(r.Rlocation("__main__/jaxlib/cuda_prng_kernels.pyd"))
|
||||
copy(r.Rlocation("__main__/jaxlib/version.py"))
|
||||
copy(r.Rlocation("__main__/jaxlib/cusolver.py"))
|
||||
copy(r.Rlocation("__main__/jaxlib/cuda_prng.py"))
|
||||
|
||||
if _is_windows():
|
||||
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.pyd"))
|
||||
else:
|
||||
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so"))
|
||||
patch_copy_xla_client_py()
|
||||
|
||||
if not _is_windows():
|
||||
copy(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"))
|
||||
patch_copy_tpu_client_py()
|
@ -31,46 +31,50 @@ Building ``jaxlib`` from source
|
||||
To build ``jaxlib`` from source, you must also install some prerequisites:
|
||||
|
||||
* a C++ compiler (g++, clang, or MSVC)
|
||||
* Numpy
|
||||
* Scipy
|
||||
* Cython
|
||||
* six (required for during the jaxlib build only, not required at install time)
|
||||
|
||||
On Ubuntu 18.04 or Debian you can install the necessary prerequisites with::
|
||||
On Ubuntu or Debian you can install the necessary prerequisites with::
|
||||
|
||||
sudo apt-get install g++ python python3-dev python3-numpy python3-scipy cython3 python3-six
|
||||
sudo apt install g++ python python3-dev
|
||||
|
||||
If you are building on a Mac, make sure XCode and the XCode command line tools
|
||||
are installed.
|
||||
|
||||
See below for Windows build instructions.
|
||||
|
||||
* Python packages: ``numpy``, ``scipy``, ``six``, ``wheel``.
|
||||
|
||||
The ``six`` package is required for during the jaxlib build only, and is not
|
||||
required at install time.
|
||||
|
||||
|
||||
If you are building on a Mac, make sure XCode and the XCode command line tools
|
||||
are installed.
|
||||
You can install the necessary Python dependencies using ``pip``::
|
||||
|
||||
You can also install the necessary Python dependencies using ``pip``::
|
||||
|
||||
pip install numpy scipy cython six
|
||||
pip install numpy scipy six wheel
|
||||
|
||||
|
||||
To build ``jaxlib`` with CUDA support, you can run::
|
||||
|
||||
python build/build.py --enable_cuda
|
||||
pip install -e build # installs jaxlib (includes XLA)
|
||||
pip install -e dist/*.whl # installs jaxlib (includes XLA)
|
||||
|
||||
|
||||
See ``python build/build.py --help`` for configuration options, including ways to
|
||||
specify the paths to CUDA and CUDNN, which you must have installed. Here
|
||||
``python`` should be the name of your Python 3 interpreter; on some systems, you
|
||||
may need to use ``python3`` instead.
|
||||
may need to use ``python3`` instead. By default, the wheel is written to the
|
||||
``dist/`` subdirectory of the current directory.
|
||||
|
||||
To build ``jaxlib`` without CUDA GPU support (CPU only), drop the ``--enable_cuda``::
|
||||
|
||||
python build/build.py
|
||||
pip install -e build # installs jaxlib (includes XLA)
|
||||
pip install dist/*.whl # installs jaxlib (includes XLA)
|
||||
|
||||
|
||||
Additional Notes for Building ``jaxlib`` from source on Windows
|
||||
...............................................................
|
||||
|
||||
On Windows, follow `Install Visual Studio <https://docs.microsoft.com/en-us/visualstudio/install/install-visual-studio?view=vs-2019>`_
|
||||
to setup latest C++ toolchain. Visual Studio 2019 version 16.5 or newer is required.
|
||||
to setup a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required.
|
||||
If you need to build with CUDA enabled, follow
|
||||
`CUDA Installation Guide <https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html>`_
|
||||
to setup CUDA environment.
|
||||
@ -89,7 +93,7 @@ for more details. Install the following packages::
|
||||
Once everything is installed. Open PowerShell, and make sure MSYS2 is in the
|
||||
path of the current session. Ensure ``bazel``, ``patch`` and ``realpath`` are
|
||||
accessible. Activate the conda environment. The following command builds with
|
||||
cuda enabled, adjust it to whatever suitable for you::
|
||||
CUDA enabled, adjust it to whatever suitable for you::
|
||||
|
||||
python .\build\build.py `
|
||||
--enable_cuda `
|
||||
|
@ -80,6 +80,7 @@ py_library(
|
||||
"cusolver.py",
|
||||
"init.py",
|
||||
"pocketfft.py",
|
||||
"setup.py",
|
||||
"version.py",
|
||||
],
|
||||
deps = [":pocketfft_flatbuffers_py"],
|
||||
|
@ -26,6 +26,7 @@ if cuda_version:
|
||||
__version__ += "+cuda" + cuda_version.replace(".", "")
|
||||
|
||||
binary_libs = [os.path.basename(f) for f in glob('jaxlib/*.so*')]
|
||||
binary_libs += [os.path.basename(f) for f in glob('jaxlib/*.pyd*')]
|
||||
|
||||
setup(
|
||||
name='jaxlib',
|
Loading…
x
Reference in New Issue
Block a user