adding os env to track JAX platform

This commit is contained in:
Shiva Shahrokhi 2022-06-14 21:21:34 +00:00
parent cd565f8f41
commit 88f1b9fae7

View File

@ -48,6 +48,7 @@ def cloud_tpu_init():
libtpu.configure_library_path()
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ['TPU_ML_PLATFORM'] = 'JAX'
# If the user has set any topology-related env vars, don't set any
# automatically.