diff --git a/jax/BUILD b/jax/BUILD index 9bd928d04..d2029a0d6 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -342,6 +342,9 @@ pytype_strict_library( pytype_strict_library( name = "cloud_tpu_init", srcs = ["_src/cloud_tpu_init.py"], + deps = [ + ":hardware_utils", + ] ) pytype_strict_library( @@ -470,6 +473,11 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "hardware_utils", + srcs = ["_src/hardware_utils.py"], +) + pytype_library( name = "lax_reference", srcs = ["_src/lax_reference.py"], @@ -831,6 +839,7 @@ pytype_strict_library( deps = [ ":cloud_tpu_init", ":config", + ":hardware_utils", ":traceback_util", ":util", "//jax/_src/lib", diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 419576941..d6547a367 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -14,6 +14,7 @@ import os import warnings +from jax._src import hardware_utils running_in_cloud_tpu_vm: bool = False @@ -66,6 +67,8 @@ def cloud_tpu_init() -> None: os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' + if hardware_utils.tpu_enhanced_barrier_supported(): + os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" # TODO(skyewm): remove this warning at some point, say around Sept 2023. use_pjrt_c_api = os.environ.get('JAX_USE_PJRT_C_API_ON_TPU', None) diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py new file mode 100644 index 000000000..3f6686d05 --- /dev/null +++ b/jax/_src/hardware_utils.py @@ -0,0 +1,59 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib
import glob

# Glob matching the sysfs `vendor` attribute file of every PCI device.
_PCI_VENDOR_GLOB = '/sys/bus/pci/devices/*/vendor'

_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
_TPU_PCI_DEVICE_IDS = [
    # TPU v2, v3
    '0x0027',
    # TPU v4
    '0x005e',
    # TPU v5e
    '0x0063',
    # Testing only
    '0x0056',
    '0x0062',
]

# PCI device ids for which tpu_enhanced_barrier_supported() reports True.
_TPU_ENHANCED_BARRIER_SUPPORTED = [
    # TPU v2, v3
    '0x0027',
    # TPU v4
    '0x005e',
]


def num_available_tpu_chips_and_device_id(
    pci_vendor_glob: str = _PCI_VENDOR_GLOB,
) -> tuple[int, str]:
  """Counts the TPU chips attached through PCI and reports their device id.

  Every path matched by `pci_vendor_glob` is read as a PCI vendor id. For
  devices whose vendor is Google, the sibling `device` file is read and
  compared against the known TPU device ids.

  Args:
    pci_vendor_glob: Glob pattern selecting the per-device `vendor` files.
      Defaults to the Linux sysfs location; overriding it allows scanning a
      different directory tree with the same layout.

  Returns:
    A `(num_chips, device_id)` tuple. `num_chips` is the number of devices
    whose device id is a known TPU id. `device_id` is the PCI device id of a
    matched TPU chip (the last one seen if several are present), or `''` when
    no TPU chip is found.
  """
  num_chips = 0
  tpu_device_id = ''
  for vendor_path in glob.glob(pci_vendor_glob):
    vendor_id = pathlib.Path(vendor_path).read_text().strip()
    if vendor_id != _GOOGLE_PCI_VENDOR_ID:
      continue

    device_path = os.path.join(os.path.dirname(vendor_path), 'device')
    device_id = pathlib.Path(device_path).read_text().strip()
    if device_id in _TPU_PCI_DEVICE_IDS:
      num_chips += 1
      # Only remember ids of actual TPU chips. glob() yields paths in
      # arbitrary order, so recording the id of every Google-vendor device
      # would let a non-TPU Google device overwrite the TPU id and make the
      # result depend on directory listing order.
      tpu_device_id = device_id

  return num_chips, tpu_device_id


def tpu_enhanced_barrier_supported() -> bool:
  """Returns whether the TPU found on this host is in the enhanced-barrier list.

  The result is True only when a TPU chip is detected through PCI and its
  device id is listed in `_TPU_ENHANCED_BARRIER_SUPPORTED`; it is False when no
  TPU chip is found.
  """
  _, device_id = num_available_tpu_chips_and_device_id()
  return device_id in _TPU_ENHANCED_BARRIER_SUPPORTED
import config from jax._src import distributed from jax._src import traceback_util from jax._src import util +from jax._src import hardware_utils from jax._src.cloud_tpu_init import maybe_import_libtpu from jax._src.lib import cuda_versions from jax._src.lib import xla_client @@ -61,7 +60,6 @@ except ImportError as e: traceback_util.register_exclusion(__file__) - XlaBackend = xla_client.Client @@ -670,41 +668,12 @@ def backends() -> dict[str, xla_client.Client]: _suggest_missing_backends() return _backends - # Code to suggest plugins that should be installed. # # Plugin vendors are welcome to add code to this list, assuming there's a # lightweight way to determine if hardware is present without requiring # the relevant plugin be installed. -_GOOGLE_PCI_VENDOR_ID = '0x1ae0' -_TPU_PCI_DEVICE_IDS = [ - # TPU v2, v3 - '0x0027', - # TPU v4 - '0x005e', - # TPU v5e - '0x0063', - # Testing only - '0x0056', - '0x0062', -] - -def _num_available_tpu_chips() -> int: - """Returns the number of TPU chips attached through PCI.""" - num_chips = 0 - for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'): - vendor_id = pathlib.Path(vendor_path).read_text().strip() - if vendor_id != _GOOGLE_PCI_VENDOR_ID: - continue - - device_path = os.path.join(os.path.dirname(vendor_path), 'device') - device_id = pathlib.Path(device_path).read_text().strip() - if device_id in _TPU_PCI_DEVICE_IDS: - num_chips += 1 - - return num_chips - def _suggest_missing_backends(): if py_platform.system() != "Linux": # If you're not using Linux (or WSL2), we don't have any suggestions at the @@ -727,7 +696,7 @@ def _suggest_missing_backends(): logger.warning("An NVIDIA GPU may be present on this machine, but a " "CUDA-enabled jaxlib is not installed. 
Falling back to " f"{default_platform}.") - elif "tpu" not in _backends and _num_available_tpu_chips() > 0: + elif "tpu" not in _backends and hardware_utils.num_available_tpu_chips_and_device_id()[0] > 0: logger.warning("A Google TPU may be present on this machine, but either a " "TPU-enabled jaxlib or libtpu is not installed. Falling " f"back to {default_platform}.")