Merge pull request #7062 from skye:jaxlib_tpu

PiperOrigin-RevId: 381063068
This commit is contained in:
jax authors 2021-06-23 10:35:30 -07:00
commit d1a25503a4

View File

@ -24,6 +24,11 @@ with open('jax/version.py') as f:
__version__ = _dct['__version__']
_minimum_jaxlib_version = _dct['_minimum_jaxlib_version']
_libtpu_version = '20210615'
_libtpu_url = (
f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/'
f'libtpu-nightly/libtpu_nightly-0.1.dev{_libtpu_version}-py3-none-any.whl')
setup(
name='jax',
version=__version__,
@ -46,6 +51,11 @@ setup(
# $ pip install jax[cpu]
'cpu': [f'jaxlib>={_minimum_jaxlib_version}'],
# Cloud TPU VM jaxlib can be installed via:
# $ pip install jax[tpu]
'tpu': [f'jaxlib=={_current_jaxlib_version}',
f'libtpu-nightly @ {_libtpu_url}'],
# CUDA installations require adding jax releases URL; e.g.
# $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
**{f'cuda{version}': f"jaxlib=={_current_jaxlib_version}+cuda{version}"