mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 14:46:07 +00:00

I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
73 lines
2.3 KiB
Python
73 lines
2.3 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
|
|
running_in_cloud_tpu_vm: bool = False
|
|
|
|
|
|
def maybe_import_libtpu():
|
|
try:
|
|
# pylint: disable=import-outside-toplevel
|
|
# pytype: disable=import-error
|
|
import libtpu
|
|
|
|
# pytype: enable=import-error
|
|
# pylint: enable=import-outside-toplevel
|
|
except ImportError:
|
|
return None
|
|
else:
|
|
return libtpu
|
|
|
|
|
|
def jax_force_tpu_init() -> bool:
|
|
return 'JAX_FORCE_TPU_INIT' in os.environ
|
|
|
|
|
|
def cloud_tpu_init() -> None:
|
|
"""Automatically sets Cloud TPU topology and other env vars.
|
|
|
|
**This must be called before the TPU runtime is loaded, which happens as soon
|
|
as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client
|
|
is imported.**
|
|
|
|
Safe to call in non-Cloud TPU environments.
|
|
|
|
Some of these environment variables are used to tell the TPU runtime what kind
|
|
of mesh topology to use. It assumes a single-host topology by default, so we
|
|
manually set them here to default to the full pod slice if applicable.
|
|
|
|
This will not set any env vars if a single topology-related env var is already
|
|
set.
|
|
"""
|
|
global running_in_cloud_tpu_vm
|
|
|
|
# We assume we are in a correctly-configured Cloud TPU environment
|
|
# if the following hold: a) libtpu is installed b) JAX_FORCE_TPU_INIT is set
|
|
# Exit early if we're not running on Cloud TPU.
|
|
libtpu_module = maybe_import_libtpu()
|
|
if libtpu_module is not None:
|
|
libtpu_module.configure_library_path()
|
|
elif not jax_force_tpu_init():
|
|
return
|
|
|
|
running_in_cloud_tpu_vm = True
|
|
|
|
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
|
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
|
|
os.environ['TPU_ML_PLATFORM'] = 'JAX'
|
|
|
|
if 'JAX_USE_PJRT_C_API_ON_TPU' not in os.environ:
|
|
os.environ['JAX_USE_PJRT_C_API_ON_TPU'] = 'true'
|