diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 82219886b..0539e4253 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -15,6 +15,7 @@ import datetime import os import re +import warnings from jax import version from jax._src import config from jax._src import hardware_utils @@ -72,7 +73,19 @@ def cloud_tpu_init() -> None: # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. libtpu_path = get_tpu_library_path() - num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0] + num_tpu_chips, tpu_id = hardware_utils.num_available_tpu_chips_and_device_id() + if ( + tpu_id is not None + and tpu_id >= hardware_utils.TpuVersion.v5e + and not hardware_utils.transparent_hugepages_enabled() + ): + warnings.warn( + 'Transparent hugepages are not enabled. TPU runtime startup and' + ' shutdown time should be significantly improved on TPU v5e and newer.' + ' If not already set, you may need to enable transparent hugepages in' + ' your VM image (sudo sh -c "echo always >' + ' /sys/kernel/mm/transparent_hugepage/enabled")' + ) if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init(): return diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index 81ef07a71..84ad9edf9 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -12,25 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import enum import os import pathlib import glob _GOOGLE_PCI_VENDOR_ID = '0x1ae0' -_TPU_PCI_DEVICE_IDS = [ - # TPU v2, v3 - '0x0027', - # No public name (plc) - '0x0056', - # TPU v4 - '0x005e', - # TPU v5p - '0x0062', - # TPU v5e - '0x0063', - # TPU v6e - '0x006f', -] _NVIDIA_GPU_DEVICES = [ '/dev/nvidia0', @@ -38,10 +25,36 @@ _NVIDIA_GPU_DEVICES = [ '/dev/dxg', # WSL2 ] + +class TpuVersion(enum.IntEnum): + # TPU v2, v3 + v2 = 0 + v3 = 1 + # No public name (plc) + plc = 2 + # TPU v4 + v4 = 3 + # TPU v5p + v5p = 4 + # TPU v5e + v5e = 5 + # TPU v6e + v6e = 6 + + +_TPU_PCI_DEVICE_IDS = { + '0x0027': TpuVersion.v3, + '0x0056': TpuVersion.plc, + '0x005e': TpuVersion.v4, + '0x0062': TpuVersion.v5p, + '0x0063': TpuVersion.v5e, + '0x006f': TpuVersion.v6e, +} + def num_available_tpu_chips_and_device_id(): """Returns the device id and number of TPU chips attached through PCI.""" num_chips = 0 - device_id = '' + tpu_version = None for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'): vendor_id = pathlib.Path(vendor_path).read_text().strip() if vendor_id != _GOOGLE_PCI_VENDOR_ID: @@ -50,12 +63,20 @@ def num_available_tpu_chips_and_device_id(): device_path = os.path.join(os.path.dirname(vendor_path), 'device') device_id = pathlib.Path(device_path).read_text().strip() if device_id in _TPU_PCI_DEVICE_IDS: + tpu_version = _TPU_PCI_DEVICE_IDS[device_id] num_chips += 1 - return num_chips, device_id + return num_chips, tpu_version def has_visible_nvidia_gpu() -> bool: """True if there's a visible nvidia gpu available on device, False otherwise.""" return any(os.path.exists(d) for d in _NVIDIA_GPU_DEVICES) + + +def transparent_hugepages_enabled() -> bool: + # See https://docs.kernel.org/admin-guide/mm/transhuge.html for more + # information about transparent huge pages. + path = pathlib.Path('/sys/kernel/mm/transparent_hugepage/enabled') + return path.exists() and path.read_text().strip() == '[always] madvise never' diff --git a/pyproject.toml b/pyproject.toml index e32b14a89..a1b9e7dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,9 @@ filterwarnings = [ # https://github.com/protocolbuffers/protobuf/issues/12186#issuecomment-1745679358 "ignore:Type google\\._upb\\._message\\.(Scalar|Message)MapContainer uses PyType_Spec with a metaclass that has custom tp_new\\. This is deprecated and will no longer be allowed in Python 3\\.14\\.:DeprecationWarning", + # TODO(b/401588349): Remove this once transparent hugepages are enabled. + "ignore:Transparent hugepages", + # NOTE: this is probably not where you want to add code to suppress a # warning. Only pytest tests look at this list, whereas Bazel tests also # check for warnings and do not check this list. Most likely, you should