From 62225926253474c6e5e4b202d5c9cf3363a02a03 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 20 Nov 2024 17:46:06 +0000 Subject: [PATCH] Fix KeyError recently introduced in cloud_tpu_init.py This fixes a bug introduced in https://github.com/jax-ml/jax/pull/24889 --- jax/_src/cloud_tpu_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 8ff52bd2f..a2f137686 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,7 +80,7 @@ def cloud_tpu_init() -> None: os.environ.setdefault('TPU_ML_PLATFORM', 'JAX') os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__) os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']: + if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''): os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true' # this makes tensorstore serialization work better on TPU