mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Split _src files custom_api_util, deprecations, effects and environment_info into separate Bazel targets.
PiperOrigin-RevId: 515708165
This commit is contained in:
parent
03fc8a4766
commit
0935a7cb31
33
jax/BUILD
33
jax/BUILD
@ -99,18 +99,14 @@ py_library_providing_imports_info(
|
||||
"_src/callback.py",
|
||||
"_src/checkify.py",
|
||||
"_src/core.py",
|
||||
"_src/custom_api_util.py",
|
||||
"_src/custom_batching.py",
|
||||
"_src/custom_derivatives.py",
|
||||
"_src/custom_transpose.py",
|
||||
"_src/debugging.py",
|
||||
"_src/deprecations.py",
|
||||
"_src/device_array.py",
|
||||
"_src/dispatch.py",
|
||||
"_src/dlpack.py",
|
||||
"_src/dtypes.py",
|
||||
"_src/effects.py",
|
||||
"_src/environment_info.py",
|
||||
"_src/errors.py",
|
||||
"_src/flatten_util.py",
|
||||
"_src/global_device_array.py",
|
||||
@ -173,7 +169,11 @@ py_library_providing_imports_info(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cloud_tpu_init",
|
||||
":custom_api_util",
|
||||
":config",
|
||||
":deprecations",
|
||||
":effects",
|
||||
":environment_info",
|
||||
":lazy_loader",
|
||||
":monitoring",
|
||||
":path",
|
||||
@ -201,6 +201,31 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "custom_api_util",
|
||||
srcs = ["_src/custom_api_util.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "deprecations",
|
||||
srcs = ["_src/deprecations.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "effects",
|
||||
srcs = ["_src/effects.py"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "environment_info",
|
||||
srcs = ["_src/environment_info.py"],
|
||||
deps = [
|
||||
":xla_bridge",
|
||||
":version",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "iree",
|
||||
srcs = ["_src/iree.py"],
|
||||
|
@ -95,7 +95,7 @@ class JaxException(Exception):
|
||||
return cls(metadata)
|
||||
|
||||
def get_effect_type(self) -> core.Effect:
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@functools.total_ordering
|
||||
|
@ -12,15 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import jax
|
||||
from jax._src import lib
|
||||
import numpy as np
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Optional, Union
|
||||
|
||||
from jax import version
|
||||
from jax._src import lib
|
||||
from jax._src import xla_bridge
|
||||
import numpy as np
|
||||
|
||||
def try_nvidia_smi() -> Optional[str]:
|
||||
try:
|
||||
return subprocess.check_output(['nvidia-smi']).decode()
|
||||
@ -33,23 +34,25 @@ def print_environment_info(return_string: bool = False) -> Union[None, str]:
|
||||
|
||||
This is useful information to include when asking a question or filing a bug.
|
||||
|
||||
Args:
|
||||
return_string (bool) : if True, return the string rather than printing to stdout.
|
||||
Args: return_string (bool) : if True, return the string rather than printing
|
||||
to stdout.
|
||||
"""
|
||||
# TODO(jakevdp): should we include other info, e.g. jax.config.values?
|
||||
python_version = sys.version.replace('\n', ' ')
|
||||
with np.printoptions(threshold=4, edgeitems=2):
|
||||
devices_short = str(np.array(jax.devices())).replace('\n', '')
|
||||
info = textwrap.dedent(f"""\
|
||||
jax: {jax.__version__}
|
||||
devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
|
||||
info = textwrap.dedent(
|
||||
f"""\
|
||||
jax: {version.__version__}
|
||||
jaxlib: {lib.version_str}
|
||||
numpy: {np.__version__}
|
||||
python: {python_version}
|
||||
jax.devices ({jax.device_count()} total, {jax.local_device_count()} local): {devices_short}
|
||||
process_count: {jax.process_count()}""")
|
||||
jax.devices ({xla_bridge.device_count()} total, {xla_bridge.local_device_count()} local): {devices_short}
|
||||
process_count: {xla_bridge.process_count()}"""
|
||||
)
|
||||
nvidia_smi = try_nvidia_smi()
|
||||
if nvidia_smi:
|
||||
info += "\n\n$ nvidia-smi\n" + nvidia_smi
|
||||
info += '\n\n$ nvidia-smi\n' + nvidia_smi
|
||||
if return_string:
|
||||
return info
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user