Set JAX_PLATFORMS=tpu,cpu on TPUs

This commit is contained in:
Ran Ran 2022-10-03 22:52:03 +00:00
parent 015a12e155
commit 78a7e161bb

View File

@ -48,6 +48,7 @@ def cloud_tpu_init():
libtpu.configure_library_path()
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'
# If the user has set any topology-related env vars, don't set any