mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Make sure the enhanced barrier is only turned on when a supported TPU is available.
This commit is contained in:
parent
41531123f4
commit
65f3e4fffd
@ -342,6 +342,9 @@ pytype_strict_library(
|
||||
pytype_strict_library(
|
||||
name = "cloud_tpu_init",
|
||||
srcs = ["_src/cloud_tpu_init.py"],
|
||||
deps = [
|
||||
":hardware_utils",
|
||||
]
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
@ -470,6 +473,11 @@ pytype_strict_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "hardware_utils",
|
||||
srcs = ["_src/hardware_utils.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "lax_reference",
|
||||
srcs = ["_src/lax_reference.py"],
|
||||
@ -831,6 +839,7 @@ pytype_strict_library(
|
||||
deps = [
|
||||
":cloud_tpu_init",
|
||||
":config",
|
||||
":hardware_utils",
|
||||
":traceback_util",
|
||||
":util",
|
||||
"//jax/_src/lib",
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from jax._src import hardware_utils
|
||||
|
||||
running_in_cloud_tpu_vm: bool = False
|
||||
|
||||
@ -66,6 +67,8 @@ def cloud_tpu_init() -> None:
|
||||
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
|
||||
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
|
||||
os.environ['TPU_ML_PLATFORM'] = 'JAX'
|
||||
if hardware_utils.tpu_enhanced_barrier_supported():
|
||||
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true"
|
||||
|
||||
# TODO(skyewm): remove this warning at some point, say around Sept 2023.
|
||||
use_pjrt_c_api = os.environ.get('JAX_USE_PJRT_C_API_ON_TPU', None)
|
||||
|
59
jax/_src/hardware_utils.py
Normal file
59
jax/_src/hardware_utils.py
Normal file
@ -0,0 +1,59 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import glob
|
||||
|
||||
_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
|
||||
_TPU_PCI_DEVICE_IDS = [
|
||||
# TPU v2, v3
|
||||
'0x0027',
|
||||
# TPU v4
|
||||
'0x005e',
|
||||
# TPU v5e
|
||||
'0x0063',
|
||||
# Testing only
|
||||
'0x0056',
|
||||
'0x0062',
|
||||
]
|
||||
|
||||
_TPU_ENHANCED_BARRIER_SUPPORTED = [
|
||||
# TPU v2, v3
|
||||
'0x0027',
|
||||
# TPU v4
|
||||
'0x005e',
|
||||
]
|
||||
|
||||
def num_available_tpu_chips_and_device_id():
|
||||
"""Returns the device id and number of TPU chips attached through PCI."""
|
||||
num_chips = 0
|
||||
device_id = ''
|
||||
for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'):
|
||||
vendor_id = pathlib.Path(vendor_path).read_text().strip()
|
||||
if vendor_id != _GOOGLE_PCI_VENDOR_ID:
|
||||
continue
|
||||
|
||||
device_path = os.path.join(os.path.dirname(vendor_path), 'device')
|
||||
device_id = pathlib.Path(device_path).read_text().strip()
|
||||
if device_id in _TPU_PCI_DEVICE_IDS:
|
||||
num_chips += 1
|
||||
|
||||
return num_chips, device_id
|
||||
|
||||
|
||||
def tpu_enhanced_barrier_supported() -> bool:
  """Reports whether the detected device id is in the enhanced-barrier list.

  Returns:
    True when the device id reported by
    `num_available_tpu_chips_and_device_id()` appears in
    `_TPU_ENHANCED_BARRIER_SUPPORTED`, False otherwise.
  """
  detected_id = num_available_tpu_chips_and_device_id()[1]
  return detected_id in _TPU_ENHANCED_BARRIER_SUPPORTED
|
@ -25,12 +25,10 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping
|
||||
import dataclasses
|
||||
from functools import lru_cache, partial
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import pkgutil
|
||||
import platform as py_platform
|
||||
import sys
|
||||
@ -42,6 +40,7 @@ from jax._src import config
|
||||
from jax._src import distributed
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src import hardware_utils
|
||||
from jax._src.cloud_tpu_init import maybe_import_libtpu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import xla_client
|
||||
@ -61,7 +60,6 @@ except ImportError as e:
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
XlaBackend = xla_client.Client
|
||||
|
||||
|
||||
@ -670,41 +668,12 @@ def backends() -> dict[str, xla_client.Client]:
|
||||
_suggest_missing_backends()
|
||||
return _backends
|
||||
|
||||
|
||||
# Code to suggest plugins that should be installed.
|
||||
#
|
||||
# Plugin vendors are welcome to add code to this list, assuming there's a
|
||||
# lightweight way to determine if hardware is present without requiring
|
||||
# the relevant plugin be installed.
|
||||
|
||||
_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
|
||||
_TPU_PCI_DEVICE_IDS = [
|
||||
# TPU v2, v3
|
||||
'0x0027',
|
||||
# TPU v4
|
||||
'0x005e',
|
||||
# TPU v5e
|
||||
'0x0063',
|
||||
# Testing only
|
||||
'0x0056',
|
||||
'0x0062',
|
||||
]
|
||||
|
||||
def _num_available_tpu_chips() -> int:
|
||||
"""Returns the number of TPU chips attached through PCI."""
|
||||
num_chips = 0
|
||||
for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'):
|
||||
vendor_id = pathlib.Path(vendor_path).read_text().strip()
|
||||
if vendor_id != _GOOGLE_PCI_VENDOR_ID:
|
||||
continue
|
||||
|
||||
device_path = os.path.join(os.path.dirname(vendor_path), 'device')
|
||||
device_id = pathlib.Path(device_path).read_text().strip()
|
||||
if device_id in _TPU_PCI_DEVICE_IDS:
|
||||
num_chips += 1
|
||||
|
||||
return num_chips
|
||||
|
||||
def _suggest_missing_backends():
|
||||
if py_platform.system() != "Linux":
|
||||
# If you're not using Linux (or WSL2), we don't have any suggestions at the
|
||||
@ -727,7 +696,7 @@ def _suggest_missing_backends():
|
||||
logger.warning("An NVIDIA GPU may be present on this machine, but a "
|
||||
"CUDA-enabled jaxlib is not installed. Falling back to "
|
||||
f"{default_platform}.")
|
||||
elif "tpu" not in _backends and _num_available_tpu_chips() > 0:
|
||||
elif "tpu" not in _backends and hardware_utils.num_available_tpu_chips_and_device_id()[0] > 0:
|
||||
logger.warning("A Google TPU may be present on this machine, but either a "
|
||||
"TPU-enabled jaxlib or libtpu is not installed. Falling "
|
||||
f"back to {default_platform}.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user