Improve support for pip install jax[cuda111]

This commit is contained in:
Jake VanderPlas 2021-06-23 11:42:04 -07:00
parent 0934161aea
commit dfcaf0feb8

View File

@ -13,10 +13,16 @@
# limitations under the License.
from setuptools import setup, find_packages
import sys
# The following should be updated with each new jaxlib release.
_current_jaxlib_version = '0.1.67'
_available_cuda_versions = ['101', '102', '110', '111']
_jaxlib_cuda_url = (
f'https://storage.googleapis.com/jax-releases/cuda{{version}}/'
f'jaxlib-{_current_jaxlib_version}+cuda{{version}}'
f'-cp{sys.version_info.major}{sys.version_info.minor}-none-manylinux2010_x86_64.whl'
)
_dct = {}
with open('jax/version.py') as f:
@ -57,8 +63,8 @@ setup(
f'libtpu-nightly @ {_libtpu_url}'],
# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
**{f'cuda{version}': f"jaxlib=={_current_jaxlib_version}+cuda{version}"
# $ pip install jax[cuda110]
**{f'cuda{version}': f"jaxlib @ {_jaxlib_cuda_url.format(version=version)}"
for version in _available_cuda_versions}
},
url='https://github.com/google/jax',