mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Enable runtime uptime telemetry for JAX on Cloud TPU, only if the flag is not set by the user explicitly. Otherwise, prefer the user preference.
PiperOrigin-RevId: 648812558
This commit is contained in:
parent
242c993cee
commit
d9bd3587c5
@ -80,7 +80,7 @@ def cloud_tpu_init() -> None:
|
||||
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
|
||||
os.environ['TPU_ML_PLATFORM'] = 'JAX'
|
||||
os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__
|
||||
os.environ['ENABLE_RUNTIME_UPTIME_TELEMETRY'] = '1'
|
||||
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
|
||||
if hardware_utils.tpu_enhanced_barrier_supported():
|
||||
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user