mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
On Cloud TPU, use pip-installed libtpu instead of system default if applicable.
This commit is contained in:
parent
3b84f85700
commit
ba972f0207
@ -21,9 +21,9 @@ def cloud_tpu_init():
|
||||
as JAX's C++ backend is loaded! I.e. call this before xla_bridge or xla_client
|
||||
is imported.**
|
||||
|
||||
These environment variables are used to tell the TPU runtime what kind of mesh
|
||||
topology to use. It assumes a single-host topology by default, so we manually
|
||||
set them here to default to the full pod slice if applicable.
|
||||
Some of these environment variables are used to tell the TPU runtime what kind
|
||||
of mesh topology to use. It assumes a single-host topology by default, so we
|
||||
manually set them here to default to the full pod slice if applicable.
|
||||
|
||||
This will not set any env vars if a single topology-related env var is already
|
||||
set.
|
||||
@ -31,6 +31,17 @@ def cloud_tpu_init():
|
||||
if not _running_in_cloud_tpu_vm():
|
||||
return
|
||||
|
||||
# Use pip-installed libtpu if applicable, rather than system default.
|
||||
try:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pytype: disable=import-error
|
||||
import libtpu
|
||||
# pytype: enable=import-error
|
||||
# pylint: enable=import-outside-toplevel
|
||||
libtpu.configure_library_path()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
||||
|
||||
# If the user has set any topology-related env vars, don't set any
|
||||
|
Loading…
x
Reference in New Issue
Block a user