README: improve Colab TPU installation discussion

This commit is contained in:
Jake VanderPlas 2023-03-15 08:54:23 -07:00
parent 9990ed2e64
commit e3444a8d42

View File

@ -506,13 +506,17 @@ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_relea
```
### pip installation: Colab TPU
Colab TPU runtimes come with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so the installation instructions differ.
The Colab TPU runtime comes with JAX pre-installed, but before importing JAX you must run the following code to initialize the TPU:
```python
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
Colab TPU runtimes use an older TPU architecture than Cloud TPU VMs, so installing `jax[tpu]` should be avoided on Colab.
If for any reason you would like to update the jax & jaxlib libraries on a Colab TPU runtime, follow the CPU instructions above (i.e. install `jax[cpu]`).
Note that Colab TPU runtimes are not compatible with JAX version 0.4.0 or newer.
If you need to re-install JAX on a Colab TPU runtime, you can use the following command:
```
!pip install jax<=0.3.25 jaxlib<=0.3.25
```
### Conda installation