mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow suppression of GPU warning via jax_platform_name
This commit is contained in:
parent
0d68dbd619
commit
c8e571ad84
@ -225,13 +225,12 @@ def backends():
|
||||
# we expect a RuntimeError.
|
||||
logging.info("Unable to initialize backend '%s': %s" % (name, err))
|
||||
continue
|
||||
if _default_backend.platform == "cpu":
|
||||
if _default_backend.platform == "cpu" and FLAGS.jax_platform_name != 'cpu':
|
||||
logging.warning('No GPU/TPU found, falling back to CPU. '
|
||||
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
|
||||
return _backends
|
||||
|
||||
|
||||
|
||||
@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence.
|
||||
def get_backend(platform=None):
|
||||
# TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
|
||||
|
@ -20,8 +20,10 @@ import enum
|
||||
from functools import partial
|
||||
import operator
|
||||
import re
|
||||
import unittest
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
import functools
|
||||
@ -55,6 +57,7 @@ config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
python_version = (sys.version_info[0], sys.version_info[1])
|
||||
numpy_version = tuple(map(int, np.__version__.split('.')[:3]))
|
||||
|
||||
|
||||
@ -5436,5 +5439,28 @@ class NamedCallTest(jtu.JaxTestCase):
|
||||
self.assertRaises(OverflowError, f, int_min - 1)
|
||||
|
||||
|
||||
class BackendsTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(not sys.executable, "test requires sys.executable")
|
||||
@unittest.skipIf(python_version < (3, 7), "test requires Python 3.7 or higher")
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def test_cpu_warning_suppression(self):
|
||||
warning_expected = (
|
||||
"import jax; "
|
||||
"jax.numpy.arange(10)")
|
||||
warning_not_expected = (
|
||||
"import jax; "
|
||||
"jax.config.update('jax_platform_name', 'cpu'); "
|
||||
"jax.numpy.arange(10)")
|
||||
|
||||
result = subprocess.run([sys.executable, '-c', warning_expected],
|
||||
check=True, capture_output=True)
|
||||
assert "No GPU/TPU found" in result.stderr.decode()
|
||||
|
||||
result = subprocess.run([sys.executable, '-c', warning_not_expected],
|
||||
check=True, capture_output=True)
|
||||
assert "No GPU/TPU found" not in result.stderr.decode()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user