mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add jax.print_environment_info()
This commit is contained in:
parent
00636617c0
commit
0fb462efd7
@ -109,3 +109,11 @@ Callbacks
|
||||
|
||||
pure_callback
|
||||
debug.callback
|
||||
|
||||
Miscellaneous
|
||||
-------------
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
print_environment_info
|
||||
|
@ -57,6 +57,7 @@ from jax._src.config import (
|
||||
transfer_guard_device_to_host as transfer_guard_device_to_host,
|
||||
)
|
||||
from .core import eval_context as ensure_compile_time_eval
|
||||
from jax._src.environment_info import print_environment_info as print_environment_info
|
||||
from jax._src.api import (
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
effects_barrier,
|
||||
|
56
jax/_src/environment_info.py
Normal file
56
jax/_src/environment_info.py
Normal file
@ -0,0 +1,56 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import jax
|
||||
from jax._src import lib
|
||||
import numpy as np
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Optional, Union
|
||||
|
||||
def try_nvidia_smi() -> Optional[str]:
|
||||
try:
|
||||
return subprocess.check_output(['nvidia-smi']).decode()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def print_environment_info(return_string: bool = False) -> Union[None, str]:
|
||||
"""Returns a string containing local environment & JAX installation information.
|
||||
|
||||
This is useful information to include when asking a question or filing a bug.
|
||||
|
||||
Args:
|
||||
return_string (bool) : if True, return the string rather than printing to stdout.
|
||||
"""
|
||||
# TODO(jakevdp): should we include other info, e.g. jax.config.values?
|
||||
python_version = sys.version.replace('\n', ' ')
|
||||
with np.printoptions(threshold=4, edgeitems=2):
|
||||
devices_short = str(np.array(jax.devices())).replace('\n', '')
|
||||
info = textwrap.dedent(f"""\
|
||||
jax: {jax.__version__}
|
||||
jaxlib: {lib.version_str}
|
||||
numpy: {np.__version__}
|
||||
python: {python_version}
|
||||
jax.devices ({jax.device_count()} total, {jax.local_device_count()} local): {devices_short}
|
||||
process_count: {jax.process_count()}""")
|
||||
nvidia_smi = try_nvidia_smi()
|
||||
if nvidia_smi:
|
||||
info += "\n\n$ nvidia-smi\n" + nvidia_smi
|
||||
if return_string:
|
||||
return info
|
||||
else:
|
||||
return print(info)
|
@ -45,7 +45,7 @@ import jax.numpy as jnp
|
||||
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
|
||||
from jax import core, lax
|
||||
from jax import custom_batching
|
||||
from jax._src import api, dtypes, dispatch
|
||||
from jax._src import api, dtypes, dispatch, lib
|
||||
from jax.core import Primitive
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import ad
|
||||
@ -8998,5 +8998,19 @@ class CleanupTest(jtu.JaxTestCase):
|
||||
assert core.trace_state_clean()
|
||||
|
||||
|
||||
class EnvironmentInfoTest(jtu.JaxTestCase):
|
||||
@parameterized.parameters([True, False])
|
||||
def test_print_environment_info(self, return_string):
|
||||
with jtu.capture_stdout() as stdout:
|
||||
result = jax.print_environment_info(return_string=return_string)
|
||||
if return_string:
|
||||
self.assertEmpty(stdout())
|
||||
else:
|
||||
self.assertIsNone(result)
|
||||
result = stdout()
|
||||
assert f"jax: {jax.__version__}" in result
|
||||
assert f"jaxlib: {lib.version_str}" in result
|
||||
assert f"numpy: {np.__version__}" in result
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user