mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fix TPU_ML_PLATFORM and TPU_ML_PLATFORM_VERSION environment variables.
The previous code was overwriting the environment variables, which could cause problems if they were already set. The new code uses os.environ.setdefault, which will only set the environment variable if it is not already set. PiperOrigin-RevId: 681608895
This commit is contained in:
parent
81d2fbe094
commit
b768b659e3
@ -77,8 +77,8 @@ def cloud_tpu_init() -> None:
|
||||
running_in_cloud_tpu_vm = True
|
||||
|
||||
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
||||
os.environ['TPU_ML_PLATFORM'] = 'JAX'
|
||||
os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__
|
||||
os.environ.setdefault('TPU_ML_PLATFORM', 'JAX')
|
||||
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__)
|
||||
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
|
||||
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user