mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove the "No GPU/TPU found" warning.
Instead, add a lightweight test for NVIDIA GPUs and Google TPUs. Warn only if we suspect either is present but JAX is not using them.
This commit is contained in:
parent
09df912c0a
commit
210fab1aae
@ -22,10 +22,12 @@ XLA. There are also a handful of related casting utilities.
|
||||
from collections.abc import Mapping
|
||||
import dataclasses
|
||||
from functools import partial, lru_cache
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform as py_platform
|
||||
import pkgutil
|
||||
import sys
|
||||
@ -36,6 +38,7 @@ import warnings
|
||||
from jax._src import distributed
|
||||
from jax._src import config as jax_config
|
||||
from jax._src.config import config
|
||||
import jax._src.lib
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import traceback_util
|
||||
@ -133,7 +136,7 @@ class BackendRegistration:
|
||||
_backend_factories: dict[str, BackendRegistration] = {}
|
||||
_default_backend: Optional[xla_client.Client] = None
|
||||
_backends : dict[str, xla_client.Client] = {}
|
||||
_backends_errors : dict[str, str] = {}
|
||||
_backend_errors : dict[str, str] = {}
|
||||
_backend_lock = threading.Lock()
|
||||
_plugins_registered: bool = False
|
||||
_plugin_lock = threading.Lock()
|
||||
@ -465,7 +468,7 @@ def is_gpu(platform):
|
||||
|
||||
def backends() -> dict[str, xla_client.Client]:
|
||||
global _backends
|
||||
global _backends_errors
|
||||
global _backend_errors
|
||||
global _default_backend
|
||||
global _plugins_registered
|
||||
|
||||
@ -515,7 +518,7 @@ def backends() -> dict[str, xla_client.Client]:
|
||||
except Exception as err:
|
||||
err_msg = f"Unable to initialize backend '{platform}': {err}"
|
||||
if fail_quietly:
|
||||
_backends_errors[platform] = str(err)
|
||||
_backend_errors[platform] = str(err)
|
||||
logger.info(err_msg)
|
||||
else:
|
||||
if config.jax_platforms:
|
||||
@ -525,25 +528,82 @@ def backends() -> dict[str, xla_client.Client]:
|
||||
raise RuntimeError(err_msg)
|
||||
|
||||
assert _default_backend is not None
|
||||
# We don't warn about falling back to CPU on Mac OS, because we don't
|
||||
# support anything else there at the moment and warning would be pointless.
|
||||
if (py_platform.system() != "Darwin" and
|
||||
_default_backend.platform == "cpu" and
|
||||
_PLATFORM_NAME.value != 'cpu'):
|
||||
logger.warning('No GPU/TPU found, falling back to CPU. '
|
||||
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
|
||||
if not config.jax_platforms:
|
||||
_suggest_missing_backends()
|
||||
return _backends
|
||||
|
||||
|
||||
# Code to suggest plugins that should be installed.
|
||||
#
|
||||
# Plugin vendors are welcome to add code to this list, assuming there's a
|
||||
# lightweight way to determine if hardware is present without requiring
|
||||
# the relevant plugin be installed.
|
||||
|
||||
_GOOGLE_PCI_VENDOR_ID = '0x1ae0'
_TPU_PCI_DEVICE_IDS = [
    # TPU v2, v3
    '0x0027',
    # TPU v4
    '0x005e',
    # TPU v5e
    '0x0063',
    # Testing only
    '0x0056',
    '0x0062',
]


def _num_available_tpu_chips(
    pci_vendor_glob: str = '/sys/bus/pci/devices/*/vendor') -> int:
  """Returns the number of TPU chips attached through PCI.

  Each path matched by `pci_vendor_glob` is read as a PCI vendor id. For
  entries whose vendor id equals `_GOOGLE_PCI_VENDOR_ID`, the sibling `device`
  file is read and compared against `_TPU_PCI_DEVICE_IDS`.

  Args:
    pci_vendor_glob: glob pattern matching the per-device `vendor` files to
      inspect. Defaults to the Linux sysfs PCI device tree.

  Returns:
    The number of matched devices whose vendor and device ids identify a TPU.
    Entries that cannot be read are skipped rather than raising.
  """
  num_chips = 0
  for vendor_path in glob.glob(pci_vendor_glob):
    try:
      vendor_id = pathlib.Path(vendor_path).read_text().strip()
      if vendor_id != _GOOGLE_PCI_VENDOR_ID:
        continue

      device_path = os.path.join(os.path.dirname(vendor_path), 'device')
      device_id = pathlib.Path(device_path).read_text().strip()
    except OSError:
      # This probe only feeds an advisory warning. An entry that disappeared
      # or is not readable must not turn into a hard failure for the caller.
      continue

    if device_id in _TPU_PCI_DEVICE_IDS:
      num_chips += 1

  return num_chips
|
||||
|
||||
def _suggest_missing_backends():
  """Logs a warning when accelerator hardware seems present but is unused.

  Only runs its checks on Linux (which includes WSL2); on any other system it
  returns immediately. Two probes are made, in order:

  * NVIDIA GPU: if no "cuda" backend was initialized and one of the known
    NVIDIA device nodes exists, warn. When `jax._src.lib.gpu_solver` is
    available and an initialization error was recorded for "cuda", the warning
    reports that error; otherwise it says a CUDA-enabled jaxlib is not
    installed.
  * Google TPU: otherwise, if no "tpu" backend was initialized and
    `_num_available_tpu_chips()` reports at least one chip, warn that a
    TPU-enabled jaxlib or libtpu is missing.

  Requires `_default_backend` to have been set already.
  """
  if py_platform.system() != "Linux":
    # If you're not using Linux (or WSL2), we don't have any suggestions at the
    # moment.
    return

  assert _default_backend is not None
  default_platform = _default_backend.platform
  nvidia_gpu_devices = [
      "/dev/nvidia0",
      "/dev/dxg",  # WSL2
  ]
  if ("cuda" not in _backends and
      any(os.path.exists(d) for d in nvidia_gpu_devices)):
    # An error is only recorded when a "cuda" backend factory actually ran and
    # failed. Look it up defensively so a missing entry cannot raise KeyError
    # from what is purely an advisory code path.
    cuda_error = _backend_errors.get("cuda")
    if jax._src.lib.gpu_solver is not None and cuda_error is not None:
      logger.warning(f"CUDA backend failed to initialize: {cuda_error} (Set "
                     "TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)")
    else:
      logger.warning("An NVIDIA GPU may be present on this machine, but a "
                     "CUDA-enabled jaxlib is not installed. Falling back to "
                     f"{default_platform}.")
  elif "tpu" not in _backends and _num_available_tpu_chips() > 0:
    logger.warning("A Google TPU may be present on this machine, but either a "
                   "TPU-enabled jaxlib or libtpu is not installed. Falling "
                   f"back to {default_platform}.")
|
||||
|
||||
|
||||
def _clear_backends() -> None:
|
||||
global _backends
|
||||
global _backends_errors
|
||||
global _backend_errors
|
||||
global _default_backend
|
||||
|
||||
logger.info("Clearing JAX backend caches.")
|
||||
with _backend_lock:
|
||||
_backends = {}
|
||||
_backends_errors = {}
|
||||
_backend_errors = {}
|
||||
_default_backend = None
|
||||
|
||||
get_backend.cache_clear()
|
||||
@ -590,9 +650,9 @@ def _get_backend_uncached(
|
||||
platform = canonicalize_platform(platform)
|
||||
backend = bs.get(platform, None)
|
||||
if backend is None:
|
||||
if platform in _backends_errors:
|
||||
if platform in _backend_errors:
|
||||
raise RuntimeError(f"Backend '{platform}' failed to initialize: "
|
||||
f"{_backends_errors[platform]}. "
|
||||
f"{_backend_errors[platform]}. "
|
||||
f'Available backends are {list(bs)}')
|
||||
raise RuntimeError(f"Unknown backend {platform}")
|
||||
return backend
|
||||
|
@ -28,7 +28,6 @@ import itertools as it
|
||||
import operator
|
||||
import operator as op
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
@ -10270,25 +10269,16 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
class BackendsTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(not sys.executable, "test requires sys.executable")
|
||||
@unittest.skipIf(platform.system() == "Darwin",
|
||||
"Warning doesn't apply on Mac")
|
||||
@jtu.run_on_devices("cpu")
|
||||
def test_cpu_warning_suppression(self):
|
||||
warning_expected = (
|
||||
"import jax; "
|
||||
"jax.numpy.arange(10)")
|
||||
def test_no_backend_warning_on_cpu_if_platform_specified(self):
|
||||
warning_not_expected = (
|
||||
"import jax; "
|
||||
"jax.config.update('jax_platform_name', 'cpu'); "
|
||||
"jax.numpy.arange(10)")
|
||||
|
||||
result = subprocess.run([sys.executable, '-c', warning_expected],
|
||||
check=True, capture_output=True)
|
||||
assert "No GPU/TPU found" in result.stderr.decode()
|
||||
|
||||
result = subprocess.run([sys.executable, '-c', warning_not_expected],
|
||||
check=True, capture_output=True)
|
||||
assert "No GPU/TPU found" not in result.stderr.decode()
|
||||
assert "may be present" not in result.stderr.decode()
|
||||
|
||||
|
||||
class CleanupTest(jtu.JaxTestCase):
|
||||
|
@ -288,18 +288,18 @@ class GetBackendTest(jtu.JaxTestCase):
|
||||
|
||||
def _save_backend_state(self):
|
||||
self._orig_backends = xb._backends
|
||||
self._orig_backends_errors = xb._backends_errors
|
||||
self._orig_backend_errors = xb._backend_errors
|
||||
self._orig_default_backend = xb._default_backend
|
||||
|
||||
def _reset_backend_state(self):
|
||||
xb._backends = {}
|
||||
xb._backends_errors = {}
|
||||
xb._backend_errors = {}
|
||||
xb._default_backend = None
|
||||
xb.get_backend.cache_clear()
|
||||
|
||||
def _restore_backend_state(self):
|
||||
xb._backends = self._orig_backends
|
||||
xb._backends_errors = self._orig_backends_errors
|
||||
xb._backend_errors = self._orig_backend_errors
|
||||
xb._default_backend = self._orig_default_backend
|
||||
xb.get_backend.cache_clear()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user