From 0935a7cb3181125beb91735e339c7aa2728a4f78 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 10 Mar 2023 12:25:25 -0800 Subject: [PATCH] Split _src files custom_api_util, deprecations, effects and environment_info into separate Bazel targets. PiperOrigin-RevId: 515708165 --- jax/BUILD | 33 +++++++++++++++++++++++++++++---- jax/_src/checkify.py | 2 +- jax/_src/environment_info.py | 27 +++++++++++++++------------ 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index f7584e29c..dc8b13918 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -99,18 +99,14 @@ py_library_providing_imports_info( "_src/callback.py", "_src/checkify.py", "_src/core.py", - "_src/custom_api_util.py", "_src/custom_batching.py", "_src/custom_derivatives.py", "_src/custom_transpose.py", "_src/debugging.py", - "_src/deprecations.py", "_src/device_array.py", "_src/dispatch.py", "_src/dlpack.py", "_src/dtypes.py", - "_src/effects.py", - "_src/environment_info.py", "_src/errors.py", "_src/flatten_util.py", "_src/global_device_array.py", @@ -173,7 +169,11 @@ py_library_providing_imports_info( visibility = ["//visibility:public"], deps = [ ":cloud_tpu_init", + ":custom_api_util", ":config", + ":deprecations", + ":effects", + ":environment_info", ":lazy_loader", ":monitoring", ":path", @@ -201,6 +201,31 @@ pytype_library( ], ) +pytype_library( + name = "custom_api_util", + srcs = ["_src/custom_api_util.py"], +) + +pytype_library( + name = "deprecations", + srcs = ["_src/deprecations.py"], +) + +pytype_library( + name = "effects", + srcs = ["_src/effects.py"], +) + +pytype_library( + name = "environment_info", + srcs = ["_src/environment_info.py"], + deps = [ + ":xla_bridge", + ":version", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_library( name = "iree", srcs = ["_src/iree.py"], diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 5449a2877..9f9b3e119 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -95,7 +95,7 @@ class JaxException(Exception): return cls(metadata) def get_effect_type(self) -> core.Effect: - pass + raise NotImplementedError @functools.total_ordering diff --git a/jax/_src/environment_info.py b/jax/_src/environment_info.py index 2c1c07d3d..edf5933ea 100644 --- a/jax/_src/environment_info.py +++ b/jax/_src/environment_info.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax -from jax._src import lib -import numpy as np - import subprocess import sys import textwrap from typing import Optional, Union +from jax import version +from jax._src import lib +from jax._src import xla_bridge +import numpy as np + def try_nvidia_smi() -> Optional[str]: try: return subprocess.check_output(['nvidia-smi']).decode() @@ -33,23 +34,25 @@ def print_environment_info(return_string: bool = False) -> Union[None, str]: This is useful information to include when asking a question or filing a bug. - Args: - return_string (bool) : if True, return the string rather than printing to stdout. + Args: return_string (bool) : if True, return the string rather than printing + to stdout. """ # TODO(jakevdp): should we include other info, e.g. jax.config.values? python_version = sys.version.replace('\n', ' ') with np.printoptions(threshold=4, edgeitems=2): - devices_short = str(np.array(jax.devices())).replace('\n', '') - info = textwrap.dedent(f"""\ - jax: {jax.__version__} + devices_short = str(np.array(xla_bridge.devices())).replace('\n', '') + info = textwrap.dedent( + f"""\ + jax: {version.__version__} jaxlib: {lib.version_str} numpy: {np.__version__} python: {python_version} - jax.devices ({jax.device_count()} total, {jax.local_device_count()} local): {devices_short} - process_count: {jax.process_count()}""") + jax.devices ({xla_bridge.device_count()} total, {xla_bridge.local_device_count()} local): {devices_short} + process_count: {xla_bridge.process_count()}""" + ) nvidia_smi = try_nvidia_smi() if nvidia_smi: - info += "\n\n$ nvidia-smi\n" + nvidia_smi + info += '\n\n$ nvidia-smi\n' + nvidia_smi if return_string: return info else: