Enable runtime uptime telemetry for JAX on Cloud TPU, only if the flag is not set by the user explicitly. Otherwise, prefer the user preference.

PiperOrigin-RevId: 648812558
This commit is contained in:
jax authors 2024-07-02 12:48:00 -07:00 committed by jax authors
parent 242c993cee
commit d9bd3587c5

View File

@ -80,7 +80,7 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'
os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__
os.environ['ENABLE_RUNTIME_UPTIME_TELEMETRY'] = '1'
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
if hardware_utils.tpu_enhanced_barrier_supported():
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"