mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Use local version identifiers to distribute cuda jaxlib wheels.
This change: * Updates our jaxlib build scripts to add `+cudaXXX` to the wheel version, where XXX is the CUDA version number (e.g. `110`). nocuda builds remain unchanged and do not have this extra identifier. * Adds `generate_release_index.py`, which writes an html page that pip can use to find the cuda wheels. (I based this format off of wheel PyTorch's index). * Updates the README to use the new local version identifier + wheel index. The end result is that the command to install cuda wheels is now much simpler. I manually made copies of the latest jaxlib 0.1.55 wheels that have the local version identifiers, so the new installation commands already work (as well as the old ones, until the next jaxlib release using the new tooling). Fow now, I put the html index to the GCP bucket with the wheels. We can move it to a prettier URL if/when we have one.
This commit is contained in:
parent
cfbaca0507
commit
cacb01753a
28
README.md
28
README.md
@ -408,25 +408,17 @@ and CUDNN7 installations on your machine (for example, preinstalled on your
|
||||
cloud VM), you can run
|
||||
|
||||
```bash
|
||||
# install jaxlib
|
||||
PYTHON_VERSION=cp37 # alternatives: cp36, cp37, cp38
|
||||
CUDA_VERSION=cuda100 # alternatives: cuda100, cuda101, cuda102, cuda110
|
||||
PLATFORM=manylinux2010_x86_64 # alternatives: manylinux2010_x86_64
|
||||
BASE_URL='https://storage.googleapis.com/jax-releases'
|
||||
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.55-$PYTHON_VERSION-none-$PLATFORM.whl
|
||||
|
||||
pip install --upgrade jax # install jax
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade jax jaxlib==0.1.55+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
```
|
||||
|
||||
The library package name must correspond to the version of the existing CUDA
|
||||
The jaxlib version must correspond to the version of the existing CUDA
|
||||
installation you want to use, with `cuda110` for CUDA 11.0, `cuda102` for CUDA
|
||||
10.2, `cuda101` for CUDA 10.1, and `cuda100` for CUDA 10.0. To find your CUDA
|
||||
and CUDNN versions, you can run commands like these, depending on your CUDNN
|
||||
install path:
|
||||
10.2, `cuda101` for CUDA 10.1, and `cuda100` for CUDA 10.0. You can find your
|
||||
CUDA version with: install path:
|
||||
|
||||
```bash
|
||||
nvcc --version
|
||||
grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different path
|
||||
```
|
||||
|
||||
Note that some GPU functionality expects the CUDA installation to be at
|
||||
@ -444,16 +436,6 @@ Or set the following environment variable before importing JAX:
|
||||
XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda
|
||||
```
|
||||
|
||||
The Python version must match your Python interpreter. There are prebuilt wheels
|
||||
for Python 3.6, 3.7, and 3.8; for anything else, you must build from source. Jax
|
||||
requires Python 3.6 or above. Jax does not support Python 2 any more.
|
||||
|
||||
To try automatic detection of the correct version for your system, you can run:
|
||||
|
||||
```bash
|
||||
pip install --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.55-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl jax
|
||||
```
|
||||
|
||||
Please let us know on [the issue tracker](https://github.com/google/jax/issues)
|
||||
if you run into any errors or problems with the prebuilt wheels.
|
||||
|
||||
|
@ -16,7 +16,7 @@ do
|
||||
for CUDA_VARIANT in $CUDA_VARIANTS
|
||||
do
|
||||
mkdir -p dist/${CUDA_VARIANT}${CUDA_VERSION//.}
|
||||
docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION $CUDA_VARIANT
|
||||
docker run -it --tmpfs /build:exec --rm -v $(pwd)/dist:/dist jaxbuild $PYTHON_VERSION $CUDA_VARIANT $CUDA_VERSION
|
||||
mv -f dist/*.whl dist/${CUDA_VARIANT}${CUDA_VERSION//.}/
|
||||
done
|
||||
done
|
||||
|
@ -26,7 +26,7 @@ usage() {
|
||||
exit 1
|
||||
}
|
||||
|
||||
if [[ $# != 2 ]]
|
||||
if [[ $# -lt 2 ]]
|
||||
then
|
||||
usage
|
||||
fi
|
||||
@ -59,5 +59,6 @@ case $2 in
|
||||
usage
|
||||
esac
|
||||
|
||||
export JAX_CUDA_VERSION=$3
|
||||
python setup.py bdist_wheel --python-tag "$PY_TAG" --plat-name "$PLAT_NAME"
|
||||
cp -r dist/* /dist
|
||||
|
52
build/generate_release_index.py
Normal file
52
build/generate_release_index.py
Normal file
@ -0,0 +1,52 @@
|
||||
#!/usr/bin/python
|
||||
#
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generates jax_releases.html package index.
|
||||
|
||||
To update public copy, use:
|
||||
gsutil cp jax_releases.html gs://jax-releases/
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
|
||||
FILENAME = "jax_releases.html"
|
||||
|
||||
HEADER = """
|
||||
<!-- Generated by jax/build/generate_release_index.py, do not update manually! -->
|
||||
<html>
|
||||
<head><meta http-equiv="Content-Type" content="text/html; charset=utf-8"></head>
|
||||
<body>
|
||||
"""
|
||||
|
||||
FOOTER = "</body>\n</html>\n"
|
||||
|
||||
print("Running command: gsutil ls gs://jax-releases/cuda*")
|
||||
ls_output = subprocess.check_output(["gsutil", "ls", "gs://jax-releases/cuda*"])
|
||||
|
||||
print(f"Writing index to {FILENAME}")
|
||||
with open(FILENAME, "w") as f:
|
||||
f.write(HEADER)
|
||||
for line in ls_output.decode("utf-8").split("\n"):
|
||||
# Skip incorrectly formatted wheel filenames and other gsutil output
|
||||
if not "+cuda" in line: continue
|
||||
# Example line:
|
||||
# gs://jax-releases/cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
|
||||
assert line.startswith("gs://jax-releases/cuda")
|
||||
link_title = line[len("gs://jax-releases/"):]
|
||||
link_href = line.replace("gs://", "https://storage.googleapis.com/")
|
||||
f.write(f'<a href="{link_href}">{link_title}</a><br>\n')
|
||||
f.write(FOOTER)
|
||||
print("Done.")
|
@ -21,6 +21,10 @@ __version__ = None
|
||||
with open('jaxlib/version.py') as f:
|
||||
exec(f.read(), globals())
|
||||
|
||||
cuda_version = os.environ.get("JAX_CUDA_VERSION")
|
||||
if cuda_version:
|
||||
__version__ += "+cuda" + cuda_version.replace(".", "")
|
||||
|
||||
binary_libs = [os.path.basename(f) for f in glob('jaxlib/*.so*')]
|
||||
|
||||
setup(
|
||||
|
Loading…
x
Reference in New Issue
Block a user