diff --git a/jax/BUILD b/jax/BUILD index 343ab07e6..9a98cf4fa 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -390,6 +390,7 @@ pytype_strict_library( srcs = ["_src/cloud_tpu_init.py"], deps = [ ":hardware_utils", + ":version", ], ) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 2cccbd301..68fe25562 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -14,6 +14,7 @@ import os from jax._src import hardware_utils +from jax import version running_in_cloud_tpu_vm: bool = False @@ -65,6 +66,7 @@ def cloud_tpu_init() -> None: os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' + os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__ if hardware_utils.tpu_enhanced_barrier_supported(): os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"