Fix KeyError recently introduced in cloud_tpu_init.py

This fixes a bug introduced in https://github.com/jax-ml/jax/pull/24889
This commit is contained in:
Skye Wanderman-Milne 2024-11-20 17:46:06 +00:00
parent 439d34da15
commit 6222592625

View File

@ -80,7 +80,7 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('TPU_ML_PLATFORM', 'JAX')
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__)
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']:
if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ.get('LIBTPU_INIT_ARGS', ''):
os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true'
# this makes tensorstore serialization work better on TPU