Use local version identifiers to distribute cuda jaxlib wheels.

This change:

* Updates our jaxlib build scripts to add `+cudaXXX` to the wheel
  version, where XXX is the CUDA version number (e.g. `110`). nocuda
  builds remain unchanged and do not have this extra identifier.

* Adds `generate_release_index.py`, which writes an html page that pip
  can use to find the cuda wheels. (I based this format off of
  wheel PyTorch's index).

* Updates the README to use the new local version identifier + wheel
  index.

The end result is that the command to install cuda wheels is now much
simpler.

I manually made copies of the latest jaxlib 0.1.55 wheels that have
the local version identifiers, so the new installation commands
already work (as well as the old ones, until the next jaxlib release
using the new tooling).

Fow now, I put the html index to the GCP bucket with the wheels. We
can move it to a prettier URL if/when we have one.
This commit is contained in:
Skye Wanderman-Milne 2020-08-13 17:30:08 -07:00
parent cfbaca0507
commit cacb01753a
5 changed files with 64 additions and 25 deletions

View File

@ -408,25 +408,17 @@ and CUDNN7 installations on your machine (for example, preinstalled on your
cloud VM), you can run
```bash
# install jaxlib
PYTHON_VERSION=cp37 # alternatives: cp36, cp37, cp38
CUDA_VERSION=cuda100 # alternatives: cuda100, cuda101, cuda102, cuda110
PLATFORM=manylinux2010_x86_64 # alternatives: manylinux2010_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.55-$PYTHON_VERSION-none-$PLATFORM.whl
pip install --upgrade jax # install jax
pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.55+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
The library package name must correspond to the version of the existing CUDA
The jaxlib version must correspond to the version of the existing CUDA
installation you want to use, with `cuda110` for CUDA 11.0, `cuda102` for CUDA
10.2, `cuda101` for CUDA 10.1, and `cuda100` for CUDA 10.0. To find your CUDA
and CUDNN versions, you can run commands like these, depending on your CUDNN
install path:
10.2, `cuda101` for CUDA 10.1, and `cuda100` for CUDA 10.0. You can find your
CUDA version with: install path:
```bash
nvcc --version
grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different path
```
Note that some GPU functionality expects the CUDA installation to be at
@ -444,16 +436,6 @@ Or set the following environment variable before importing JAX:
XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda
```
The Python version must match your Python interpreter. There are prebuilt wheels
for Python 3.6, 3.7, and 3.8; for anything else, you must build from source. Jax
requires Python 3.6 or above. Jax does not support Python 2 any more.
To try automatic detection of the correct version for your system, you can run:
```bash
pip install --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.55-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl jax
```
Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.

View File

@ -16,7 +16,7 @@ do
for CUDA_VARIANT in $CUDA_VARIANTS
do
mkdir -p dist/${CUDA_VARIANT}${CUDA_VERSION//.}
docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION $CUDA_VARIANT
docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION $CUDA_VARIANT $CUDA_VERSION
mv -f dist/*.whl dist/${CUDA_VARIANT}${CUDA_VERSION//.}/
done
done

View File

@ -26,7 +26,7 @@ usage() {
exit 1
}
if [[ $# != 2 ]]
if [[ $# -lt 2 ]]
then
usage
fi
@ -59,5 +59,6 @@ case $2 in
usage
esac
export JAX_CUDA_VERSION=$3
python setup.py bdist_wheel --python-tag "$PY_TAG" --plat-name "$PLAT_NAME"
cp -r dist/* /dist

View File

@ -0,0 +1,52 @@
#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates jax_releases.html package index.
To update public copy, use:
gsutil cp jax_releases.html gs://jax-releases/
"""
import subprocess
FILENAME = "jax_releases.html"
HEADER = """
<!-- Generated by jax/build/generate_release_index.py, do not update manually! -->
<html>
<head><meta http-equiv="Content-Type" content="text/html; charset=utf-8"></head>
<body>
"""
FOOTER = "</body>\n</html>\n"
print("Running command: gsutil ls gs://jax-releases/cuda*")
ls_output = subprocess.check_output(["gsutil", "ls", "gs://jax-releases/cuda*"])
print(f"Writing index to {FILENAME}")
with open(FILENAME, "w") as f:
f.write(HEADER)
for line in ls_output.decode("utf-8").split("\n"):
# Skip incorrectly formatted wheel filenames and other gsutil output
if not "+cuda" in line: continue
# Example line:
# gs://jax-releases/cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
assert line.startswith("gs://jax-releases/cuda")
link_title = line[len("gs://jax-releases/"):]
link_href = line.replace("gs://", "https://storage.googleapis.com/")
f.write(f'<a href="{link_href}">{link_title}</a><br>\n')
f.write(FOOTER)
print("Done.")

View File

@ -21,6 +21,10 @@ __version__ = None
with open('jaxlib/version.py') as f:
exec(f.read(), globals())
cuda_version = os.environ.get("JAX_CUDA_VERSION")
if cuda_version:
__version__ += "+cuda" + cuda_version.replace(".", "")
binary_libs = [os.path.basename(f) for f in glob('jaxlib/*.so*')]
setup(