mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Change documentation to recommend libtpu from pypi instead of GCS.
This commit is contained in:
parent
11b45a6f63
commit
85add667ed
@ -398,8 +398,8 @@ Some standouts:
|
||||
|-----------------|-----------------------------------------------------------------------------------------------------------------|
|
||||
| CPU | `pip install -U jax` |
|
||||
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
|
||||
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
|
||||
| AMD GPU (Linux) | Follow [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md). |
|
||||
| Google TPU | `pip install -U "jax[tpu]"` |
|
||||
| AMD GPU (Linux) | Follow [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md). |
|
||||
| Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
|
||||
| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |
|
||||
|
||||
|
@ -20,7 +20,7 @@ different builds for different operating systems and accelerators.
|
||||
|
||||
* **TPU (Google Cloud TPU VM)**
|
||||
```
|
||||
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
pip install -U "jax[tpu]"
|
||||
```
|
||||
|
||||
(install-supported-platforms)=
|
||||
@ -199,7 +199,7 @@ To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can
|
||||
the following in your cloud TPU VM:
|
||||
|
||||
```bash
|
||||
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
pip install "jax[tpu]"
|
||||
```
|
||||
|
||||
For users of Colab (https://colab.research.google.com/), be sure you are
|
||||
|
2
setup.py
2
setup.py
@ -73,7 +73,7 @@ setup(
|
||||
'ci': [f'jaxlib=={_latest_jaxlib_version_on_pypi}'],
|
||||
|
||||
# Cloud TPU VM jaxlib can be installed via:
|
||||
# $ pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
# $ pip install "jax[tpu]"
|
||||
'tpu': [
|
||||
f'jaxlib>={_current_jaxlib_version},<={_jax_version}',
|
||||
f'libtpu=={_libtpu_version}',
|
||||
|
Loading…
x
Reference in New Issue
Block a user