Add jax.print_environment_info()

This commit is contained in:
Jake VanderPlas 2022-09-12 15:39:33 -07:00
parent 00636617c0
commit 0fb462efd7
4 changed files with 80 additions and 1 deletions

View File

@ -109,3 +109,11 @@ Callbacks
pure_callback
debug.callback
Miscellaneous
-------------
.. autosummary::
:toctree: _autosummary
print_environment_info

View File

@ -57,6 +57,7 @@ from jax._src.config import (
transfer_guard_device_to_host as transfer_guard_device_to_host,
)
from .core import eval_context as ensure_compile_time_eval
from jax._src.environment_info import print_environment_info as print_environment_info
from jax._src.api import (
ad, # TODO(phawkins): update users to avoid this.
effects_barrier,

View File

@ -0,0 +1,56 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
from jax._src import lib
import numpy as np
import subprocess
import sys
import textwrap
from typing import Optional, Union
def try_nvidia_smi() -> Optional[str]:
try:
return subprocess.check_output(['nvidia-smi']).decode()
except Exception:
return None
def print_environment_info(return_string: bool = False) -> Union[None, str]:
"""Returns a string containing local environment & JAX installation information.
This is useful information to include when asking a question or filing a bug.
Args:
return_string (bool) : if True, return the string rather than printing to stdout.
"""
# TODO(jakevdp): should we include other info, e.g. jax.config.values?
python_version = sys.version.replace('\n', ' ')
with np.printoptions(threshold=4, edgeitems=2):
devices_short = str(np.array(jax.devices())).replace('\n', '')
info = textwrap.dedent(f"""\
jax: {jax.__version__}
jaxlib: {lib.version_str}
numpy: {np.__version__}
python: {python_version}
jax.devices ({jax.device_count()} total, {jax.local_device_count()} local): {devices_short}
process_count: {jax.process_count()}""")
nvidia_smi = try_nvidia_smi()
if nvidia_smi:
info += "\n\n$ nvidia-smi\n" + nvidia_smi
if return_string:
return info
else:
return print(info)

View File

@ -45,7 +45,7 @@ import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax import core, lax
from jax import custom_batching
from jax._src import api, dtypes, dispatch
from jax._src import api, dtypes, dispatch, lib
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
@ -8998,5 +8998,19 @@ class CleanupTest(jtu.JaxTestCase):
assert core.trace_state_clean()
class EnvironmentInfoTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_print_environment_info(self, return_string):
with jtu.capture_stdout() as stdout:
result = jax.print_environment_info(return_string=return_string)
if return_string:
self.assertEmpty(stdout())
else:
self.assertIsNone(result)
result = stdout()
assert f"jax: {jax.__version__}" in result
assert f"jaxlib: {lib.version_str}" in result
assert f"numpy: {np.__version__}" in result
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())