mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Prepare for jax and jaxlib 0.4.0 release
PiperOrigin-RevId: 493733609
This commit is contained in:
parent
dd647601c6
commit
0118f8d568
18
README.md
18
README.md
@ -427,15 +427,15 @@ learning systems, JAX does not bundle CUDA or CuDNN as part of the `pip`
|
||||
package.
|
||||
|
||||
JAX provides pre-built CUDA-compatible wheels for **Linux only**,
|
||||
with CUDA 11.1 or newer, and CuDNN 8.0.5 or newer. Note these existing wheels are currently for `x86_64` architectures only. Other combinations of
|
||||
with CUDA 11.4 or newer, and CuDNN 8.2 or newer. Note these existing wheels are currently for `x86_64` architectures only. Other combinations of
|
||||
operating system, CUDA, and CuDNN are possible, but require [building from
|
||||
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
|
||||
|
||||
* CUDA 11.1 or newer is *required*.
|
||||
* CUDA 11.4 or newer is *required*.
|
||||
* The supported cuDNN versions for the prebuilt wheels are:
|
||||
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
|
||||
* cuDNN 8.6 or newer. We recommend using the cuDNN 8.6 wheel if your cuDNN
|
||||
installation is new enough, since it supports additional functionality.
|
||||
* cuDNN 8.0.5 or newer.
|
||||
* cuDNN 8.2 or newer.
|
||||
* You *must* use an NVidia driver version that is at least as new as your
|
||||
[CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).
|
||||
For example, if you have CUDA 11.4 update 4 installed, you must use NVidia
|
||||
@ -453,7 +453,7 @@ Next, run
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
|
||||
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
|
||||
# Note: wheels only available on linux.
|
||||
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
@ -468,11 +468,11 @@ version for jaxlib explicitly:
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
|
||||
# Installs the wheel compatible with Cuda >= 11.8 and cudnn >= 8.6
|
||||
pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
|
||||
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
|
||||
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
```
|
||||
|
||||
You can find your CUDA version with the command:
|
||||
@ -483,7 +483,7 @@ nvcc --version
|
||||
|
||||
Some GPU functionality expects the CUDA installation to be at
|
||||
`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number
|
||||
(e.g. `cuda-11.1`). If CUDA is installed elsewhere on your system, you can either
|
||||
(e.g. `cuda-11.8`). If CUDA is installed elsewhere on your system, you can either
|
||||
create a symlink:
|
||||
|
||||
```bash
|
||||
|
@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
# and update the sha256 with the result.
|
||||
http_archive(
|
||||
name = "org_tensorflow",
|
||||
sha256 = "bfd40279b247d2d0b0dc5c5a776b595c9d4979889dcf0529c85fe9f6ff7a5255",
|
||||
strip_prefix = "tensorflow-c21f137bc42450f10f7d04f9d263852827afd079",
|
||||
sha256 = "47edef97c9b23661fd63621d522454f30772ac70a1fb5ff82864e566ef86be78",
|
||||
strip_prefix = "tensorflow-f3cc513887e06150b6f870c522220dabedc58920",
|
||||
urls = [
|
||||
"https://github.com/tensorflow/tensorflow/archive/c21f137bc42450f10f7d04f9d263852827afd079.tar.gz",
|
||||
"https://github.com/tensorflow/tensorflow/archive/f3cc513887e06150b6f870c522220dabedc58920.tar.gz",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -22,7 +22,7 @@ from jax.config import config
|
||||
TPU_DRIVER_MODE = 0
|
||||
|
||||
|
||||
def setup_tpu(tpu_driver_version='tpu_driver_20221109'):
|
||||
def setup_tpu(tpu_driver_version='tpu_driver_20221207'):
|
||||
"""Sets up Colab to run on TPU.
|
||||
|
||||
Note: make sure the Colab Runtime is set to Accelerator: TPU.
|
||||
|
6
setup.py
6
setup.py
@ -19,14 +19,14 @@ import sys
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
_current_jaxlib_version = '0.3.25'
|
||||
_current_jaxlib_version = '0.4.0'
|
||||
# The following should be updated with each new jaxlib release.
|
||||
_latest_jaxlib_version_on_pypi = '0.3.25'
|
||||
_available_cuda_versions = ['11']
|
||||
_default_cuda_version = '11'
|
||||
_available_cudnn_versions = ['82', '86']
|
||||
_default_cudnn_version = '86'
|
||||
_libtpu_version = '0.1.dev20221109'
|
||||
_libtpu_version = '0.1.dev20221207'
|
||||
|
||||
_dct = {}
|
||||
with open('jax/version.py') as f:
|
||||
@ -96,7 +96,7 @@ setup(
|
||||
|
||||
# CUDA installations require adding jax releases URL; e.g.
|
||||
# $ pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
# $ pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
# $ pip install jax[cuda11_cudnn86] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
**{f'cuda{cuda_version}_cudnn{cudnn_version}': f"jaxlib=={_current_jaxlib_version}+cuda{cuda_version}.cudnn{cudnn_version}"
|
||||
for cuda_version in _available_cuda_versions for cudnn_version in _available_cudnn_versions}
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user