Fix TPU_ML_PLATFORM and TPU_ML_PLATFORM_VERSION environment variables.

The previous code was overwriting the environment variables, which could cause problems if they were already set. The new code uses os.environ.setdefault, which will only set the environment variable if it is not already set.

PiperOrigin-RevId: 681608895
This commit is contained in:
jax authors 2024-10-02 15:14:08 -07:00
parent 81d2fbe094
commit b768b659e3

View File

@ -77,8 +77,8 @@ def cloud_tpu_init() -> None:
running_in_cloud_tpu_vm = True
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ['TPU_ML_PLATFORM'] = 'JAX'
os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__
os.environ.setdefault('TPU_ML_PLATFORM', 'JAX')
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__)
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"