mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Improve support for pip install jax[cuda111]
This commit is contained in:
parent
0934161aea
commit
dfcaf0feb8
10
setup.py
10
setup.py
@ -13,10 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
import sys
|
||||
|
||||
# The following should be updated with each new jaxlib release.
|
||||
_current_jaxlib_version = '0.1.67'
|
||||
_available_cuda_versions = ['101', '102', '110', '111']
|
||||
_jaxlib_cuda_url = (
|
||||
f'https://storage.googleapis.com/jax-releases/cuda{{version}}/'
|
||||
f'jaxlib-{_current_jaxlib_version}+cuda{{version}}'
|
||||
f'-cp{sys.version_info.major}{sys.version_info.minor}-none-manylinux2010_x86_64.whl'
|
||||
)
|
||||
|
||||
_dct = {}
|
||||
with open('jax/version.py') as f:
|
||||
@ -57,8 +63,8 @@ setup(
|
||||
f'libtpu-nightly @ {_libtpu_url}'],
|
||||
|
||||
# CUDA installations require adding jax releases URL; e.g.
|
||||
# $ pip install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
|
||||
**{f'cuda{version}': f"jaxlib=={_current_jaxlib_version}+cuda{version}"
|
||||
# $ pip install jax[cuda110]
|
||||
**{f'cuda{version}': f"jaxlib @ {_jaxlib_cuda_url.format(version=version)}"
|
||||
for version in _available_cuda_versions}
|
||||
},
|
||||
url='https://github.com/google/jax',
|
||||
|
Loading…
x
Reference in New Issue
Block a user