mirror of
https://github.com/ROCm/jax.git
synced 2025-04-28 11:36:06 +00:00

The previous code was overwriting the environment variables, which could cause problems if they were already set. The new code uses os.environ.setdefault, which will only set the environment variable if it is not already set. PiperOrigin-RevId: 681608895
100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
from jax import version
|
|
from jax._src import config
|
|
from jax._src import hardware_utils
|
|
|
|
running_in_cloud_tpu_vm: bool = False
|
|
|
|
|
|
def maybe_import_libtpu():
|
|
try:
|
|
# pylint: disable=import-outside-toplevel
|
|
# pytype: disable=import-error
|
|
import libtpu
|
|
|
|
# pytype: enable=import-error
|
|
# pylint: enable=import-outside-toplevel
|
|
except ImportError:
|
|
return None
|
|
else:
|
|
return libtpu
|
|
|
|
|
|
def get_tpu_library_path() -> str | None:
|
|
path_from_env = os.getenv("TPU_LIBRARY_PATH")
|
|
if path_from_env is not None and os.path.isfile(path_from_env):
|
|
return path_from_env
|
|
|
|
libtpu_module = maybe_import_libtpu()
|
|
if libtpu_module is not None:
|
|
return libtpu_module.get_library_path()
|
|
|
|
return None
|
|
|
|
|
|
def jax_force_tpu_init() -> bool:
|
|
return 'JAX_FORCE_TPU_INIT' in os.environ
|
|
|
|
|
|
def cloud_tpu_init() -> None:
|
|
"""Automatically sets Cloud TPU topology and other env vars.
|
|
|
|
**This must be called before the TPU runtime is loaded, which happens as soon
|
|
as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client
|
|
is imported.**
|
|
|
|
Safe to call in non-Cloud TPU environments.
|
|
|
|
Some of these environment variables are used to tell the TPU runtime what kind
|
|
of mesh topology to use. It assumes a single-host topology by default, so we
|
|
manually set them here to default to the full pod slice if applicable.
|
|
|
|
This will not set any env vars if a single topology-related env var is already
|
|
set.
|
|
"""
|
|
global running_in_cloud_tpu_vm
|
|
|
|
# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
|
|
libtpu_path = get_tpu_library_path()
|
|
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
|
|
if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init():
|
|
return
|
|
|
|
running_in_cloud_tpu_vm = True
|
|
|
|
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
|
os.environ.setdefault('TPU_ML_PLATFORM', 'JAX')
|
|
os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__)
|
|
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
|
|
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
|
|
|
|
# this makes tensorstore serialization work better on TPU
|
|
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60')
|
|
os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256')
|
|
|
|
# If the JAX_PLATFORMS env variable isn't set, config.jax_platforms defaults
|
|
# to None. In this case, we set it to 'tpu,cpu' to ensure that JAX uses the
|
|
# TPU backend.
|
|
if config.jax_platforms.value is None:
|
|
config.update('jax_platforms', 'tpu,cpu')
|
|
|
|
if config.jax_pjrt_client_create_options.value is None:
|
|
config.update(
|
|
'jax_pjrt_client_create_options',
|
|
f'ml_framework_name:JAX;ml_framework_version:{version.__version__}'
|
|
)
|