mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add JAX version to TPU_ML_PLATFORM_VERSION environment variable.
This will allow us to track the JAX version that is being used on Cloud TPUs PiperOrigin-RevId: 637025132
This commit is contained in:
parent
683ca2cd40
commit
93170d9c80
@ -390,6 +390,7 @@ pytype_strict_library(
|
||||
srcs = ["_src/cloud_tpu_init.py"],
|
||||
deps = [
|
||||
":hardware_utils",
|
||||
":version",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import os
|
||||
from jax._src import hardware_utils
|
||||
from jax import version
|
||||
|
||||
running_in_cloud_tpu_vm: bool = False
|
||||
|
||||
@ -65,6 +66,7 @@ def cloud_tpu_init() -> None:
|
||||
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
||||
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
|
||||
os.environ['TPU_ML_PLATFORM'] = 'JAX'
|
||||
os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__
|
||||
if hardware_utils.tpu_enhanced_barrier_supported():
|
||||
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user