mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix pip install jax[tpu]
* Updates jax_releases.html index to include libtpu wheels * Change [tpu] extras to specify `libtpu-nightly` instead of wheel URL The full install command will now be: `pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html` (similar to the cuda install commands) I've already pushed an updated jax_releases.html to the jax-releases GCS bucket.
This commit is contained in:
parent
2460f91561
commit
55276d15e4
@ -20,6 +20,7 @@ To update public copy, use:
|
||||
gsutil cp jax_releases.html gs://jax-releases/
|
||||
"""
|
||||
|
||||
from itertools import chain
|
||||
import subprocess
|
||||
|
||||
FILENAME = "jax_releases.html"
|
||||
@ -33,20 +34,34 @@ HEADER = """
|
||||
|
||||
FOOTER = "</body>\n</html>\n"
|
||||
|
||||
print("Running command: gsutil ls gs://jax-releases/cuda*")
|
||||
ls_output = subprocess.check_output(["gsutil", "ls", "gs://jax-releases/cuda*"])
|
||||
def get_entries(gcs_uri, whl_filter=".whl"):
|
||||
entries = []
|
||||
print(f"Running command: gsutil ls {gcs_uri}")
|
||||
ls_output = subprocess.check_output(["gsutil", "ls", gcs_uri])
|
||||
for line in ls_output.decode("utf-8").split("\n"):
|
||||
# Skip incorrectly formatted wheel filenames and other gsutil output
|
||||
if not whl_filter in line: continue
|
||||
# Example lines:
|
||||
# gs://jax-releases/cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
|
||||
# gs://cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20210615-py3-none-any.whl
|
||||
|
||||
# Link title should be the innermost directory + wheel filename
|
||||
# Example link titles:
|
||||
# cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
|
||||
# libtpu-nightly/libtpu_nightly-0.1.dev20210615-py3-none-any.whl
|
||||
link_title_idx = line.rfind('/', 0, line.rfind('/')) + 1
|
||||
link_title = line[link_title_idx:]
|
||||
link_href = line.replace("gs://", "https://storage.googleapis.com/")
|
||||
entries.append(f'<a href="{link_href}">{link_title}</a><br>\n')
|
||||
return entries
|
||||
|
||||
jaxlib_cuda_entries = get_entries("gs://jax-releases/cuda*", whl_filter="+cuda")
|
||||
libtpu_entries = get_entries("gs://cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/")
|
||||
|
||||
print(f"Writing index to {FILENAME}")
|
||||
with open(FILENAME, "w") as f:
|
||||
f.write(HEADER)
|
||||
for line in ls_output.decode("utf-8").split("\n"):
|
||||
# Skip incorrectly formatted wheel filenames and other gsutil output
|
||||
if not "+cuda" in line: continue
|
||||
# Example line:
|
||||
# gs://jax-releases/cuda101/jaxlib-0.1.52+cuda101-cp38-none-manylinux2010_x86_64.whl
|
||||
assert line.startswith("gs://jax-releases/cuda")
|
||||
link_title = line[len("gs://jax-releases/"):]
|
||||
link_href = line.replace("gs://", "https://storage.googleapis.com/")
|
||||
f.write(f'<a href="{link_href}">{link_title}</a><br>\n')
|
||||
for entry in chain(jaxlib_cuda_entries, libtpu_entries):
|
||||
f.write(entry)
|
||||
f.write(FOOTER)
|
||||
print("Done.")
|
||||
|
9
setup.py
9
setup.py
@ -24,10 +24,7 @@ with open('jax/version.py') as f:
|
||||
__version__ = _dct['__version__']
|
||||
_minimum_jaxlib_version = _dct['_minimum_jaxlib_version']
|
||||
|
||||
_libtpu_version = '20210615'
|
||||
_libtpu_url = (
|
||||
f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/'
|
||||
f'libtpu-nightly/libtpu_nightly-0.1.dev{_libtpu_version}-py3-none-any.whl')
|
||||
_libtpu_version = '0.1.dev20210615'
|
||||
|
||||
setup(
|
||||
name='jax',
|
||||
@ -52,9 +49,9 @@ setup(
|
||||
'cpu': [f'jaxlib>={_minimum_jaxlib_version}'],
|
||||
|
||||
# Cloud TPU VM jaxlib can be installed via:
|
||||
# $ pip install jax[tpu]
|
||||
# $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
'tpu': [f'jaxlib=={_current_jaxlib_version}',
|
||||
f'libtpu-nightly @ {_libtpu_url}'],
|
||||
f'libtpu-nightly=={_libtpu_version}'],
|
||||
|
||||
# CUDA installations require adding jax releases URL; e.g.
|
||||
# $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
|
Loading…
x
Reference in New Issue
Block a user