rocm_jax/jax/_src/cloud_tpu_init.py
jax authors b768b659e3 Fix TPU_ML_PLATFORM and TPU_ML_PLATFORM_VERSION environment variables.
The previous code was overwriting the environment variables, which could cause problems if they were already set. The new code uses os.environ.setdefault, which will only set the environment variable if it is not already set.

PiperOrigin-RevId: 681608895
2024-10-02 15:14:50 -07:00

100 lines
3.3 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from jax import version
from jax._src import config
from jax._src import hardware_utils
running_in_cloud_tpu_vm: bool = False
def maybe_import_libtpu():
try:
# pylint: disable=import-outside-toplevel
# pytype: disable=import-error
import libtpu
# pytype: enable=import-error
# pylint: enable=import-outside-toplevel
except ImportError:
return None
else:
return libtpu
def get_tpu_library_path() -> str | None:
path_from_env = os.getenv("TPU_LIBRARY_PATH")
if path_from_env is not None and os.path.isfile(path_from_env):
return path_from_env
libtpu_module = maybe_import_libtpu()
if libtpu_module is not None:
return libtpu_module.get_library_path()
return None
def jax_force_tpu_init() -> bool:
return 'JAX_FORCE_TPU_INIT' in os.environ
def cloud_tpu_init() -> None:
"""Automatically sets Cloud TPU topology and other env vars.
**This must be called before the TPU runtime is loaded, which happens as soon
as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client
is imported.**
Safe to call in non-Cloud TPU environments.
Some of these environment variables are used to tell the TPU runtime what kind
of mesh topology to use. It assumes a single-host topology by default, so we
manually set them here to default to the full pod slice if applicable.
This will not set any env vars if a single topology-related env var is already
set.
"""
global running_in_cloud_tpu_vm
# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
libtpu_path = get_tpu_library_path()
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init():
return
running_in_cloud_tpu_vm = True
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('TPU_ML_PLATFORM', 'JAX')
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__)
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
# this makes tensorstore serialization work better on TPU
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')
# If the JAX_PLATFORMS env variable isn't set, config.jax_platforms defaults
# to None. In this case, we set it to 'tpu,cpu' to ensure that JAX uses the
# TPU backend.
if config.jax_platforms.value is None:
config.update('jax_platforms', 'tpu,cpu')
if config.jax_pjrt_client_create_options.value is None:
config.update(
'jax_pjrt_client_create_options',
f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
)