Merge pull request #7083 from skye:fix_install

PiperOrigin-RevId: 381117784
This commit is contained in:
jax authors 2021-06-23 14:39:04 -07:00
commit 4c1856e1b9
2 changed files with 31 additions and 25 deletions

View File

@ -20,6 +20,7 @@ To update public copy, use:
gsutil cp jax_releases.html gs://jax-releases/
"""
from itertools import chain
import subprocess
FILENAME = "jax_releases.html"
@ -33,20 +34,34 @@ HEADER = """
FOOTER = "</body>\n</html>\n"
print("Running command: gsutil ls gs://jax-releases/cuda*")
ls_output = subprocess.check_output(["gsutil", "ls", "gs://jax-releases/cuda*"])
def get_entries(gcs_uri, whl_filter=".whl"):
entries = []
print(f"Running command: gsutil ls {gcs_uri}")
ls_output = subprocess.check_output(["gsutil", "ls", gcs_uri])
for line in ls_output.decode("utf-8").split("\n"):
# Skip incorrectly formatted wheel filenames and other gsutil output
if not whl_filter in line: continue
# Example lines:
# gs://jax-releases/cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
# gs://cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20210615-py3-none-any.whl
# Link title should be the innermost directory + wheel filename
# Example link titles:
# cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
# libtpu-nightly/libtpu_nightly-0.1.dev20210615-py3-none-any.whl
link_title_idx = line.rfind('/', 0, line.rfind('/')) + 1
link_title = line[link_title_idx:]
link_href = line.replace("gs://", "https://storage.googleapis.com/")
entries.append(f'<a href="{link_href}">{link_title}</a><br>\n')
return entries
jaxlib_cuda_entries = get_entries("gs://jax-releases/cuda*", whl_filter="+cuda")
libtpu_entries = get_entries("gs://cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/")
print(f"Writing index to {FILENAME}")
with open(FILENAME, "w") as f:
f.write(HEADER)
for line in ls_output.decode("utf-8").split("\n"):
# Skip incorrectly formatted wheel filenames and other gsutil output
if not "+cuda" in line: continue
# Example line:
# gs://jax-releases/cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
assert line.startswith("gs://jax-releases/cuda")
link_title = line[len("gs://jax-releases/"):]
link_href = line.replace("gs://", "https://storage.googleapis.com/")
f.write(f'<a href="{link_href}">{link_title}</a><br>\n')
for entry in chain(jaxlib_cuda_entries, libtpu_entries):
f.write(entry)
f.write(FOOTER)
print("Done.")

View File

@ -13,16 +13,10 @@
# limitations under the License.
from setuptools import setup, find_packages
import sys
# The following should be updated with each new jaxlib release.
_current_jaxlib_version = '0.1.68'
_available_cuda_versions = ['101', '102', '110', '111']
_jaxlib_cuda_url = (
f'https://storage.googleapis.com/jax-releases/cuda{{version}}/'
f'jaxlib-{_current_jaxlib_version}+cuda{{version}}'
f'-cp{sys.version_info.major}{sys.version_info.minor}-none-manylinux2010_x86_64.whl'
)
_dct = {}
with open('jax/version.py') as f:
@ -30,10 +24,7 @@ with open('jax/version.py') as f:
__version__ = _dct['__version__']
_minimum_jaxlib_version = _dct['_minimum_jaxlib_version']
_libtpu_version = '20210615'
_libtpu_url = (
f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/'
f'libtpu-nightly/libtpu_nightly-0.1.dev{_libtpu_version}-py3-none-any.whl')
_libtpu_version = '0.1.dev20210615'
setup(
name='jax',
@ -58,13 +49,13 @@ setup(
'cpu': [f'jaxlib>={_minimum_jaxlib_version}'],
# Cloud TPU VM jaxlib can be installed via:
# $ pip install jax[tpu]
# $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html
'tpu': [f'jaxlib=={_current_jaxlib_version}',
f'libtpu-nightly @ {_libtpu_url}'],
f'libtpu-nightly=={_libtpu_version}'],
# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda110]
**{f'cuda{version}': f"jaxlib @ {_jaxlib_cuda_url.format(version=version)}"
# $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
**{f'cuda{version}': f"jaxlib=={_current_jaxlib_version}+cuda{version}"
for version in _available_cuda_versions}
},
url='https://github.com/google/jax',