Use shorter error message for jax.tools.colab_tpu.setup_tpu()

This commit is contained in:
Jake VanderPlas 2024-03-05 11:20:38 -08:00
parent 32fec820ed
commit d8a4ea42cc

View File

@ -14,26 +14,9 @@
"""Utilities for running JAX on Cloud TPUs via Colab."""
import textwrap
message = """
As of JAX 0.4.0, JAX only supports TPU VMs, not the older Colab TPUs.
We recommend trying Kaggle Notebooks
(https://www.kaggle.com/code, click on "New Notebook" near the top) which offer
TPU VMs. You have to create an account, log in, and verify your account to get
accelerator support.
Once you do that, there's a new "TPU 1VM v3-8" accelerator option. This gives
you a TPU notebook environment similar to Colab, but using the newer TPU VM
architecture. This should be a less buggy, more performant, and overall better
experience than the older TPU node architecture.
It is also possible to use Colab together with a self-hosted Jupyter kernel
running on a Cloud TPU VM. See
https://research.google.com/colaboratory/local-runtimes.html
for details.
"""
def setup_tpu(tpu_driver_version=None):
"""Returns an error. Do not use."""
raise RuntimeError(textwrap.dedent(message))
"""Raises an error. Do not use."""
raise RuntimeError(
"jax.tools.colab_tpu.setup_tpu() was required for older JAX versions"
" running on older generations of TPUs, and should no longer be used.")