diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 2ceab757d..c1b7c4f59 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -14,6 +14,7 @@ import os + def cloud_tpu_init(): """Automatically sets Cloud TPU topology and other env vars. @@ -63,16 +64,18 @@ def cloud_tpu_init(): # pylint: enable=import-outside-toplevel # Based on https://github.com/tensorflow/tensorflow/pull/40317 - gce_metadata_endpoint = 'http://' + os.environ.get('GCE_METADATA_IP', - 'metadata.google.internal') + gce_metadata_endpoint = 'http://' + os.environ.get( + 'GCE_METADATA_IP', 'metadata.google.internal') + def get_metadata(key): return requests.get( f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}', - headers={'Metadata-Flavor': 'Google'}).text + headers={ + 'Metadata-Flavor': 'Google' + }).text worker_id = get_metadata('agent-worker-number') accelerator_type = get_metadata('accelerator-type') - worker_network_endpoints = get_metadata('worker-network-endpoints') accelerator_type_to_host_bounds = { 'v2-8': '1,1,1', @@ -88,27 +91,12 @@ def cloud_tpu_init(): 'v3-512': '8,8,1', 'v3-1024': '8,16,1', 'v3-2048': '16,16,1', - 'v4-8': '1,1,1', - 'v4-16': '1,1,2', - 'v4-32': '1,1,4', - 'v4-64': '1,2,4', - 'v4-128': '2,2,4', - 'v4-256': '2,2,8', - 'v4-512': '2,4,8', - 'v4-1024': '4,4,8', - 'v4-2048': '4,4,16', - 'v4-4096': '4,8,16', } os.environ['CLOUD_TPU_TASK_ID'] = worker_id - os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1' - os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[ - accelerator_type] - os.environ['TPU_MESH_CONTROLLER_ADDRESS'] = worker_network_endpoints.split( - ',')[0].split(':')[2] + ':8476' - os.environ['TPU_MESH_CONTROLLER_PORT'] = '8476' - if (not os.environ.get('TPU_TOPOLOGY_WRAP', None) - and 'v4' in accelerator_type - and accelerator_type not in ['v4-8', 'v4-16', 'v4-32', 'v4-64']): - os.environ['TPU_TOPOLOGY_WRAP'] = 'true,true,true' + # If v4 TPU don't set any topology related flags, libtpu will set these values. + if not accelerator_type.startswith('v4-'): + os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1' + os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[ + accelerator_type]