mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #7062 from skye:jaxlib_tpu
PiperOrigin-RevId: 381063068
This commit is contained in:
commit
d1a25503a4
10
setup.py
10
setup.py
@ -24,6 +24,11 @@ with open('jax/version.py') as f:
|
||||
__version__ = _dct['__version__']
|
||||
_minimum_jaxlib_version = _dct['_minimum_jaxlib_version']
|
||||
|
||||
_libtpu_version = '20210615'
|
||||
_libtpu_url = (
|
||||
f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/'
|
||||
f'libtpu-nightly/libtpu_nightly-0.1.dev{_libtpu_version}-py3-none-any.whl')
|
||||
|
||||
setup(
|
||||
name='jax',
|
||||
version=__version__,
|
||||
@ -46,6 +51,11 @@ setup(
|
||||
# $ pip install jax[cpu]
|
||||
'cpu': [f'jaxlib>={_minimum_jaxlib_version}'],
|
||||
|
||||
# Cloud TPU VM jaxlib can be installed via:
|
||||
# $ pip install jax[tpu]
|
||||
'tpu': [f'jaxlib=={_current_jaxlib_version}',
|
||||
f'libtpu-nightly @ {_libtpu_url}'],
|
||||
|
||||
# CUDA installations require adding jax releases URL; e.g.
|
||||
# $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
**{f'cuda{version}': f"jaxlib=={_current_jaxlib_version}+cuda{version}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user