From 0fb462efd716013d3c04eb5ae1f5da02fab9203a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 12 Sep 2022 15:39:33 -0700 Subject: [PATCH] Add jax.print_environment_info() --- docs/jax.rst | 8 ++++++ jax/__init__.py | 1 + jax/_src/environment_info.py | 56 ++++++++++++++++++++++++++++++++++++ tests/api_test.py | 16 ++++++++++- 4 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 jax/_src/environment_info.py diff --git a/docs/jax.rst b/docs/jax.rst index e4f8aa4dc..2f6442958 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -109,3 +109,11 @@ Callbacks pure_callback debug.callback + +Miscellaneous +------------- + +.. autosummary:: + :toctree: _autosummary + + print_environment_info diff --git a/jax/__init__.py b/jax/__init__.py index 75b54317e..dc0dfcb43 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -57,6 +57,7 @@ from jax._src.config import ( transfer_guard_device_to_host as transfer_guard_device_to_host, ) from .core import eval_context as ensure_compile_time_eval +from jax._src.environment_info import print_environment_info as print_environment_info from jax._src.api import ( ad, # TODO(phawkins): update users to avoid this. effects_barrier, diff --git a/jax/_src/environment_info.py b/jax/_src/environment_info.py new file mode 100644 index 000000000..5cfc6a3d0 --- /dev/null +++ b/jax/_src/environment_info.py @@ -0,0 +1,56 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +from jax._src import lib +import numpy as np + +import subprocess +import sys +import textwrap +from typing import Optional, Union + +def try_nvidia_smi() -> Optional[str]: + try: + return subprocess.check_output(['nvidia-smi']).decode() + except Exception: + return None + + +def print_environment_info(return_string: bool = False) -> Union[None, str]: + """Returns a string containing local environment & JAX installation information. + + This is useful information to include when asking a question or filing a bug. + + Args: + return_string (bool) : if True, return the string rather than printing to stdout. + """ + # TODO(jakevdp): should we include other info, e.g. jax.config.values? + python_version = sys.version.replace('\n', ' ') + with np.printoptions(threshold=4, edgeitems=2): + devices_short = str(np.array(jax.devices())).replace('\n', '') + info = textwrap.dedent(f"""\ + jax: {jax.__version__} + jaxlib: {lib.version_str} + numpy: {np.__version__} + python: {python_version} + jax.devices ({jax.device_count()} total, {jax.local_device_count()} local): {devices_short} + process_count: {jax.process_count()}""") + nvidia_smi = try_nvidia_smi() + if nvidia_smi: + info += "\n\n$ nvidia-smi\n" + nvidia_smi + if return_string: + return info + else: + return print(info) diff --git a/tests/api_test.py b/tests/api_test.py index e3201dde7..9413e01a1 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -45,7 +45,7 @@ import jax.numpy as jnp from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian from jax import core, lax from jax import custom_batching -from jax._src import api, dtypes, dispatch +from jax._src import api, dtypes, dispatch, lib from jax.core import Primitive from jax.errors import UnexpectedTracerError from jax.interpreters import ad @@ -8998,5 +8998,19 @@ class CleanupTest(jtu.JaxTestCase): assert core.trace_state_clean() +class EnvironmentInfoTest(jtu.JaxTestCase): + @parameterized.parameters([True, False]) + def test_print_environment_info(self, return_string): + with jtu.capture_stdout() as stdout: + result = jax.print_environment_info(return_string=return_string) + if return_string: + self.assertEmpty(stdout()) + else: + self.assertIsNone(result) + result = stdout() + assert f"jax: {jax.__version__}" in result + assert f"jaxlib: {lib.version_str}" in result + assert f"numpy: {np.__version__}" in result + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())