Split _src files custom_api_util, deprecations, effects and environment_info into separate Bazel targets.

PiperOrigin-RevId: 515708165
This commit is contained in:
Peter Hawkins 2023-03-10 12:25:25 -08:00 committed by jax authors
parent 03fc8a4766
commit 0935a7cb31
3 changed files with 45 additions and 17 deletions

View File

@ -99,18 +99,14 @@ py_library_providing_imports_info(
"_src/callback.py",
"_src/checkify.py",
"_src/core.py",
"_src/custom_api_util.py",
"_src/custom_batching.py",
"_src/custom_derivatives.py",
"_src/custom_transpose.py",
"_src/debugging.py",
"_src/deprecations.py",
"_src/device_array.py",
"_src/dispatch.py",
"_src/dlpack.py",
"_src/dtypes.py",
"_src/effects.py",
"_src/environment_info.py",
"_src/errors.py",
"_src/flatten_util.py",
"_src/global_device_array.py",
@ -173,7 +169,11 @@ py_library_providing_imports_info(
visibility = ["//visibility:public"],
deps = [
":cloud_tpu_init",
":custom_api_util",
":config",
":deprecations",
":effects",
":environment_info",
":lazy_loader",
":monitoring",
":path",
@ -201,6 +201,31 @@ pytype_library(
],
)
pytype_library(
name = "custom_api_util",
srcs = ["_src/custom_api_util.py"],
)
pytype_library(
name = "deprecations",
srcs = ["_src/deprecations.py"],
)
pytype_library(
name = "effects",
srcs = ["_src/effects.py"],
)
pytype_library(
name = "environment_info",
srcs = ["_src/environment_info.py"],
deps = [
":xla_bridge",
":version",
"//jax/_src/lib",
] + py_deps("numpy"),
)
pytype_library(
name = "iree",
srcs = ["_src/iree.py"],

View File

@ -95,7 +95,7 @@ class JaxException(Exception):
return cls(metadata)
def get_effect_type(self) -> core.Effect:
pass
raise NotImplementedError
@functools.total_ordering

View File

@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
from jax._src import lib
import numpy as np
import subprocess
import sys
import textwrap
from typing import Optional, Union
from jax import version
from jax._src import lib
from jax._src import xla_bridge
import numpy as np
def try_nvidia_smi() -> Optional[str]:
try:
return subprocess.check_output(['nvidia-smi']).decode()
@ -33,23 +34,25 @@ def print_environment_info(return_string: bool = False) -> Union[None, str]:
This is useful information to include when asking a question or filing a bug.
Args:
return_string (bool) : if True, return the string rather than printing to stdout.
Args: return_string (bool) : if True, return the string rather than printing
to stdout.
"""
# TODO(jakevdp): should we include other info, e.g. jax.config.values?
python_version = sys.version.replace('\n', ' ')
with np.printoptions(threshold=4, edgeitems=2):
devices_short = str(np.array(jax.devices())).replace('\n', '')
info = textwrap.dedent(f"""\
jax: {jax.__version__}
devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
info = textwrap.dedent(
f"""\
jax: {version.__version__}
jaxlib: {lib.version_str}
numpy: {np.__version__}
python: {python_version}
jax.devices ({jax.device_count()} total, {jax.local_device_count()} local): {devices_short}
process_count: {jax.process_count()}""")
jax.devices ({xla_bridge.device_count()} total, {xla_bridge.local_device_count()} local): {devices_short}
process_count: {xla_bridge.process_count()}"""
)
nvidia_smi = try_nvidia_smi()
if nvidia_smi:
info += "\n\n$ nvidia-smi\n" + nvidia_smi
info += '\n\n$ nvidia-smi\n' + nvidia_smi
if return_string:
return info
else: