mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add cuda_pip
extra for jaxlib
PiperOrigin-RevId: 534957585
This commit is contained in:
parent
bf8ed6a543
commit
557ca52f10
@ -47,6 +47,28 @@ setup(
|
||||
packages=['jaxlib', 'jaxlib.xla_extension'],
|
||||
python_requires='>=3.8',
|
||||
install_requires=['scipy>=1.7', 'numpy>=1.21', 'ml_dtypes>=0.1.0'],
|
||||
extras_require={
|
||||
'cuda11_pip': [
|
||||
"nvidia-cublas-cu11>=11.11",
|
||||
"nvidia-cuda-cupti-cu11>=11.8",
|
||||
"nvidia-cuda-nvcc-cu11>=11.8",
|
||||
"nvidia-cuda-runtime-cu11>=11.8",
|
||||
"nvidia-cudnn-cu11>=8.6",
|
||||
"nvidia-cufft-cu11>=10.9",
|
||||
"nvidia-cusolver-cu11>=11.4",
|
||||
"nvidia-cusparse-cu11>=11.7",
|
||||
],
|
||||
'cuda12_pip': [
|
||||
"nvidia-cublas-cu12",
|
||||
"nvidia-cuda-cupti-cu12",
|
||||
"nvidia-cuda-nvcc-cu12",
|
||||
"nvidia-cuda-runtime-cu12",
|
||||
"nvidia-cudnn-cu12",
|
||||
"nvidia-cufft-cu12",
|
||||
"nvidia-cusolver-cu12",
|
||||
"nvidia-cusparse-cu12",
|
||||
],
|
||||
},
|
||||
url='https://github.com/google/jax',
|
||||
license='Apache-2.0',
|
||||
classifiers=[
|
||||
|
Loading…
x
Reference in New Issue
Block a user