Allow suppression of GPU warning via jax_platform_name

This commit is contained in:
Jake VanderPlas 2021-06-28 12:54:21 -07:00
parent 0d68dbd619
commit c8e571ad84
2 changed files with 28 additions and 3 deletions

View File

@ -225,13 +225,12 @@ def backends():
# we expect a RuntimeError.
logging.info("Unable to initialize backend '%s': %s" % (name, err))
continue
if _default_backend.platform == "cpu":
if _default_backend.platform == "cpu" and FLAGS.jax_platform_name != 'cpu':
logging.warning('No GPU/TPU found, falling back to CPU. '
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
return _backends
@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
def get_backend(platform=None):
# TODO(mattjj,skyewm): remove this input polymorphism after we clean up how

View File

@ -20,8 +20,10 @@ import enum
from functools import partial
import operator
import re
import unittest
import subprocess
import sys
import types
import unittest
import warnings
import weakref
import functools
@ -55,6 +57,7 @@ config.parse_flags_with_absl()
FLAGS = config.FLAGS
python_version = (sys.version_info[0], sys.version_info[1])
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
@ -5436,5 +5439,28 @@ class NamedCallTest(jtu.JaxTestCase):
self.assertRaises(OverflowError, f, int_min - 1)
class BackendsTest(jtu.JaxTestCase):
@unittest.skipIf(not sys.executable, "test requires sys.executable")
@unittest.skipIf(python_version < (3, 7), "test requires Python 3.7 or higher")
@jtu.skip_on_devices("gpu", "tpu")
def test_cpu_warning_suppression(self):
warning_expected = (
"import jax; "
"jax.numpy.arange(10)")
warning_not_expected = (
"import jax; "
"jax.config.update('jax_platform_name', 'cpu'); "
"jax.numpy.arange(10)")
result = subprocess.run([sys.executable, '-c', warning_expected],
check=True, capture_output=True)
assert "No GPU/TPU found" in result.stderr.decode()
result = subprocess.run([sys.executable, '-c', warning_not_expected],
check=True, capture_output=True)
assert "No GPU/TPU found" not in result.stderr.decode()
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())