Add JAX version to TPU_ML_PLATFORM_VERSION environment variable.

This will allow us to track the JAX version that is being used on Cloud TPUs

PiperOrigin-RevId: 637025132
This commit is contained in:
jax authors 2024-05-24 13:55:34 -07:00 committed by jax authors
parent 683ca2cd40
commit 93170d9c80
2 changed files with 3 additions and 0 deletions

View File

@ -390,6 +390,7 @@ pytype_strict_library(
srcs = ["_src/cloud_tpu_init.py"],
deps = [
":hardware_utils",
":version",
],
)

View File

@ -14,6 +14,7 @@
import os
from jax._src import hardware_utils
from jax import version
running_in_cloud_tpu_vm: bool = False
@ -65,6 +66,7 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'
os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__
if hardware_utils.tpu_enhanced_barrier_supported():
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"