From e2aa9391478ccebeca2cb17ac7da0ad40039daea Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 11 Oct 2022 17:49:10 -0700 Subject: [PATCH] Update Colab TPU driver version --- CHANGELOG.md | 2 ++ jax/tools/colab_tpu.py | 5 ++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e35a3cc2..bcab0b127 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. --> ## jax 0.3.23 +* Changes + * Update Colab TPU driver version for new jaxlib release. ## jaxlib 0.3.23 diff --git a/jax/tools/colab_tpu.py b/jax/tools/colab_tpu.py index 4540de658..35278a2e5 100644 --- a/jax/tools/colab_tpu.py +++ b/jax/tools/colab_tpu.py @@ -22,7 +22,7 @@ from jax.config import config TPU_DRIVER_MODE = 0 -def setup_tpu(tpu_driver_version='tpu_driver-0.2'): +def setup_tpu(tpu_driver_version='tpu_driver_20221011'): """Sets up Colab to run on TPU. Note: make sure the Colab Runtime is set to Accelerator: TPU. @@ -30,8 +30,7 @@ def setup_tpu(tpu_driver_version='tpu_driver-0.2'): Args ---- tpu_driver_version : (str) specify the version identifier for the tpu driver. - Defaults to "tpu_driver-0.2", which can be used with jaxlib 0.3.20. Set to - "tpu_driver_nightly" to use the nightly tpu driver build. + Set to "tpu_driver_nightly" to use the nightly tpu driver build. """ global TPU_DRIVER_MODE