Remove the "No GPU/TPU found" warning.

Instead, add a lightweight test for NVIDIA GPUs and Google TPUs. Warn
only if we suspect either is present but JAX is not using them.
This commit is contained in:
Peter Hawkins 2023-09-23 20:06:19 +00:00
parent 09df912c0a
commit 210fab1aae
3 changed files with 79 additions and 29 deletions

View File

@ -22,10 +22,12 @@ XLA. There are also a handful of related casting utilities.
from collections.abc import Mapping
import dataclasses
from functools import partial, lru_cache
import glob
import importlib
import json
import logging
import os
import pathlib
import platform as py_platform
import pkgutil
import sys
@ -36,6 +38,7 @@ import warnings
from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import config
import jax._src.lib
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src import traceback_util
@ -133,7 +136,7 @@ class BackendRegistration:
_backend_factories: dict[str, BackendRegistration] = {}
_default_backend: Optional[xla_client.Client] = None
_backends : dict[str, xla_client.Client] = {}
_backends_errors : dict[str, str] = {}
_backend_errors : dict[str, str] = {}
_backend_lock = threading.Lock()
_plugins_registered: bool = False
_plugin_lock = threading.Lock()
@ -465,7 +468,7 @@ def is_gpu(platform):
def backends() -> dict[str, xla_client.Client]:
global _backends
global _backends_errors
global _backend_errors
global _default_backend
global _plugins_registered
@ -515,7 +518,7 @@ def backends() -> dict[str, xla_client.Client]:
except Exception as err:
err_msg = f"Unable to initialize backend '{platform}': {err}"
if fail_quietly:
_backends_errors[platform] = str(err)
_backend_errors[platform] = str(err)
logger.info(err_msg)
else:
if config.jax_platforms:
@ -525,25 +528,82 @@ def backends() -> dict[str, xla_client.Client]:
raise RuntimeError(err_msg)
assert _default_backend is not None
# We don't warn about falling back to CPU on Mac OS, because we don't
# support anything else there at the moment and warning would be pointless.
if (py_platform.system() != "Darwin" and
_default_backend.platform == "cpu" and
_PLATFORM_NAME.value != 'cpu'):
logger.warning('No GPU/TPU found, falling back to CPU. '
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
if not config.jax_platforms:
_suggest_missing_backends()
return _backends
# Code to suggest plugins that should be installed.
#
# Plugin vendors are welcome to add code to this list, assuming there's a
# lightweight way to determine if hardware is present without requiring
# the relevant plugin be installed.
_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
_TPU_PCI_DEVICE_IDS = [
# TPU v2, v3
'0x0027',
# TPU v4
'0x005e',
# TPU v5e
'0x0063',
# Testing only
'0x0056',
'0x0062',
]
def _num_available_tpu_chips() -> int:
"""Returns the number of TPU chips attached through PCI."""
num_chips = 0
for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'):
vendor_id = pathlib.Path(vendor_path).read_text().strip()
if vendor_id != _GOOGLE_PCI_VENDOR_ID:
continue
device_path = os.path.join(os.path.dirname(vendor_path), 'device')
device_id = pathlib.Path(device_path).read_text().strip()
if device_id in _TPU_PCI_DEVICE_IDS:
num_chips += 1
return num_chips
def _suggest_missing_backends():
if py_platform.system() != "Linux":
# If you're not using Linux (or WSL2), we don't have any suggestions at the
# moment.
return
assert _default_backend is not None
default_platform = _default_backend.platform
nvidia_gpu_devices = [
"/dev/nvidia0",
"/dev/dxg", # WSL2
]
if ("cuda" not in _backends and
any(os.path.exists(d) for d in nvidia_gpu_devices)):
if jax._src.lib.gpu_solver is not None:
err = _backend_errors["cuda"]
logger.warning(f"CUDA backend failed to initialize: {err} (Set "
"TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)")
else:
logger.warning("An NVIDIA GPU may be present on this machine, but a "
"CUDA-enabled jaxlib is not installed. Falling back to "
f"{default_platform}.")
elif "tpu" not in _backends and _num_available_tpu_chips() > 0:
logger.warning("A Google TPU may be present on this machine, but either a "
"TPU-enabled jaxlib or libtpu is not installed. Falling "
f"back to {default_platform}.")
def _clear_backends() -> None:
global _backends
global _backends_errors
global _backend_errors
global _default_backend
logger.info("Clearing JAX backend caches.")
with _backend_lock:
_backends = {}
_backends_errors = {}
_backend_errors = {}
_default_backend = None
get_backend.cache_clear()
@ -590,9 +650,9 @@ def _get_backend_uncached(
platform = canonicalize_platform(platform)
backend = bs.get(platform, None)
if backend is None:
if platform in _backends_errors:
if platform in _backend_errors:
raise RuntimeError(f"Backend '{platform}' failed to initialize: "
f"{_backends_errors[platform]}. "
f"{_backend_errors[platform]}. "
f'Available backends are {list(bs)}')
raise RuntimeError(f"Unknown backend {platform}")
return backend

View File

@ -28,7 +28,6 @@ import itertools as it
import operator
import operator as op
import os
import platform
import re
import subprocess
import sys
@ -10270,25 +10269,16 @@ class NamedCallTest(jtu.JaxTestCase):
class BackendsTest(jtu.JaxTestCase):
@unittest.skipIf(not sys.executable, "test requires sys.executable")
@unittest.skipIf(platform.system() == "Darwin",
"Warning doesn't apply on Mac")
@jtu.run_on_devices("cpu")
def test_cpu_warning_suppression(self):
warning_expected = (
"import jax; "
"jax.numpy.arange(10)")
def test_no_backend_warning_on_cpu_if_platform_specified(self):
warning_not_expected = (
"import jax; "
"jax.config.update('jax_platform_name', 'cpu'); "
"jax.numpy.arange(10)")
result = subprocess.run([sys.executable, '-c', warning_expected],
check=True, capture_output=True)
assert "No GPU/TPU found" in result.stderr.decode()
result = subprocess.run([sys.executable, '-c', warning_not_expected],
check=True, capture_output=True)
assert "No GPU/TPU found" not in result.stderr.decode()
assert "may be present" not in result.stderr.decode()
class CleanupTest(jtu.JaxTestCase):

View File

@ -288,18 +288,18 @@ class GetBackendTest(jtu.JaxTestCase):
def _save_backend_state(self):
self._orig_backends = xb._backends
self._orig_backends_errors = xb._backends_errors
self._orig_backend_errors = xb._backend_errors
self._orig_default_backend = xb._default_backend
def _reset_backend_state(self):
xb._backends = {}
xb._backends_errors = {}
xb._backend_errors = {}
xb._default_backend = None
xb.get_backend.cache_clear()
def _restore_backend_state(self):
xb._backends = self._orig_backends
xb._backends_errors = self._orig_backends_errors
xb._backend_errors = self._orig_backend_errors
xb._default_backend = self._orig_default_backend
xb.get_backend.cache_clear()