Remove the topology flags that were set for all v4 TPUs. These flags will now be set by libtpu.

Remove deprecated fields `TPU_MESH_CONTROLLER_ADDRESS` and `TPU_MESH_CONTROLLER_PORT`.

PiperOrigin-RevId: 442663216
This commit is contained in:
Gain Hagenau 2022-04-18 16:39:01 -07:00 committed by jax authors
parent 4d68f4efb2
commit 59d8b8d6b2

View File

@ -14,6 +14,7 @@
import os
def cloud_tpu_init():
"""Automatically sets Cloud TPU topology and other env vars.
@ -63,16 +64,18 @@ def cloud_tpu_init():
# pylint: enable=import-outside-toplevel
# Based on https://github.com/tensorflow/tensorflow/pull/40317
gce_metadata_endpoint = 'http://' + os.environ.get('GCE_METADATA_IP',
'metadata.google.internal')
gce_metadata_endpoint = 'http://' + os.environ.get(
'GCE_METADATA_IP', 'metadata.google.internal')
def get_metadata(key):
return requests.get(
f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
headers={'Metadata-Flavor': 'Google'}).text
headers={
'Metadata-Flavor': 'Google'
}).text
worker_id = get_metadata('agent-worker-number')
accelerator_type = get_metadata('accelerator-type')
worker_network_endpoints = get_metadata('worker-network-endpoints')
accelerator_type_to_host_bounds = {
'v2-8': '1,1,1',
@ -88,27 +91,12 @@ def cloud_tpu_init():
'v3-512': '8,8,1',
'v3-1024': '8,16,1',
'v3-2048': '16,16,1',
'v4-8': '1,1,1',
'v4-16': '1,1,2',
'v4-32': '1,1,4',
'v4-64': '1,2,4',
'v4-128': '2,2,4',
'v4-256': '2,2,8',
'v4-512': '2,4,8',
'v4-1024': '4,4,8',
'v4-2048': '4,4,16',
'v4-4096': '4,8,16',
}
os.environ['CLOUD_TPU_TASK_ID'] = worker_id
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
accelerator_type]
os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split(
',')[0].split(':')[2] + ':8476'
os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476'
if (not os.environ.get('TPU_TOPOLOGY_WRAP', None)
and 'v4' in accelerator_type
and accelerator_type not in ['v4-8', 'v4-16', 'v4-32', 'v4-64']):
os.environ['TPU_TOPOLOGY_WRAP'] = 'true,true,true'
# For v4 TPUs, don't set any topology-related flags; libtpu will set these values.
if not accelerator_type.startswith('v4-'):
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
accelerator_type]