mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
README: improve Colab TPU installation discussion
This commit is contained in:
parent
9990ed2e64
commit
e3444a8d42
10
README.md
10
README.md
@ -506,13 +506,17 @@ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_relea
|
||||
```
|
||||
|
||||
### pip installation: Colab TPU
|
||||
Colab TPU runtimes come with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
|
||||
Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ.
|
||||
The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
|
||||
```python
|
||||
import jax.tools.colab_tpu
|
||||
jax.tools.colab_tpu.setup_tpu()
|
||||
```
|
||||
Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so installing `jax[tpu]` should be avoided on Colab.
|
||||
If for any reason you would like to update the jax & jaxlib libraries on a Colab TPU runtime, follow the CPU instructions above (i.e. install `jax[cpu]`).
|
||||
Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer.
|
||||
If you need to re-install JAX on a Colab TPU runtime, you can use the following command:
|
||||
```
|
||||
!pip install jax<=0.3.25 jaxlib<=0.3.25
|
||||
```
|
||||
|
||||
### Conda installation
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user